""" NaFlex Vision Transformer An improved version of the Vision Transformer with: 1. Encapsulated embedding and position encoding in a single module 2. Support for linear patch embedding on pre-patchified inputs 3. Support for NaFlex variable aspect, variable resolution 4. Support for FlexiViT variable patch size 5. Support for NaViT fractional/factorized position embedding Based on ideas from: - Original Vision Transformer: https://arxiv.org/abs/2010.11929 - FlexiViT: https://arxiv.org/abs/2212.08013 - NaViT: https://arxiv.org/abs/2307.06304 - NaFlex (SigLip-2): https://arxiv.org/abs/2502.14786 Hacked together by / Copyright 2025, Ross Wightman, Hugging Face """ import logging import math from dataclasses import dataclass, fields, replace from functools import partial from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import ( AttentionPoolLatent, Mlp, LayerNorm, PatchDropoutWithIndices, PatchEmbedInterpolator, _assert, to_2tuple, get_act_layer, get_norm_layer, apply_keep_indices_nlc, disable_compiler, calculate_drop_path_rates, ) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._features_fx import register_notrace_function, register_notrace_module from ._manipulate import checkpoint, named_apply from ._registry import register_model, generate_default_cfgs from .eva import EvaBlock from .vision_transformer import Block, global_pool_nlc __all__ = ['NaFlexVitCfg', 'NaFlexVit'] _logger = logging.getLogger(__name__) @dataclass class NaFlexVitCfg: """Configuration for FlexVit model. This dataclass contains the bulk of model configuration parameters, with core parameters (img_size, in_chans, num_classes, etc.) remaining as direct constructor arguments for API compatibility. """ # Architecture parameters patch_size: Union[int, Tuple[int, int]] = 16 embed_dim: int = 768 depth: int = 12 num_heads: int = 12 mlp_ratio: float = 4.0 scale_mlp_norm: bool = False # Apply scaling norm to MLP # Attention parameters qkv_bias: bool = True qk_norm: bool = False proj_bias: bool = True attn_drop_rate: float = 0.0 scale_attn_inner_norm: bool = False # Apply scaling norm to attn context # Regularization init_values: Optional[float] = None # Layer-scale init values (layer-scale enabled if not None) drop_rate: float = 0.0 # Dropout rate for classifier pos_drop_rate: float = 0.0 # Dropout rate for position embeddings patch_drop_rate: float = 0.0 # Dropout rate for patch tokens proj_drop_rate: float = 0.0 # Dropout rate for linear projections drop_path_rate: float = 0.0 # Stochastic depth drop rate # Prefix token configuration class_token: bool = False # Use class token reg_tokens: int = 0 # Number of register tokens # Position embedding configuration pos_embed: str = 'learned' # Type of position embedding ('learned', 'factorized', 'rope', 'none') pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16) # Grid size for position embedding initialization pos_embed_interp_mode: str = 'bicubic' # Interpolation mode for position embedding resizing pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation # ROPE specific configuration rope_type: str = '' # ROPE type: '' or 'none' for no ROPE, 'axial' for standard, 'mixed' for learnable frequencies rope_temperature: float = 10000.0 # Temperature for ROPE frequency computation rope_ref_feat_shape: Optional[Tuple[int, int]] = None rope_grid_offset: float = 0. # Grid offset for non-pixel ROPE mode rope_grid_indexing: str = 'ij' # Grid indexing mode for ROPE ('ij' or 'xy') # Image processing dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution # Other architecture choices pre_norm: bool = False # Whether to apply normalization before attention/MLP layers (start of blocks) final_norm: bool = True # Whether to apply final normalization before pooling and classifier (end of blocks) fc_norm: Optional[bool] = None # Whether to normalize features before final classifier (after pooling) # Global pooling setup global_pool: str = 'map' # Type of global pooling for final sequence pool_include_prefix: bool = False # Whether to include class/register prefix tokens in global pooling attn_pool_num_heads: Optional[int] = None # Override num_heads for attention pool attn_pool_mlp_ratio: Optional[float] = None # Override mlp_ratio for attention pool # Weight initialization weight_init: str = '' # Weight initialization scheme fix_init: bool = True # Apply weight initialization fix (scaling w/ layer index) # Embedding configuration embed_proj_type: str = 'linear' # Type of embedding layer ('conv' or 'linear') input_norm_layer: Optional[str] = None # Normalization layer for embeddings input (before input projection) embed_norm_layer: Optional[str] = None # Normalization layer for embeddings (after input projection) # Layer implementations norm_layer: Optional[str] = None # Normalization layer for transformer blocks act_layer: Optional[str] = None # Activation layer for MLP blocks block_fn: Optional[str] = None # Transformer block implementation class name mlp_layer: Optional[str] = None # MLP implementation class name # EVA-specific parameters attn_type: str = 'standard' # Attention type: 'standard', 'eva', 'rope' swiglu_mlp: bool = False # Use SwiGLU MLP variant qkv_fused: bool = True # Whether to use fused QKV projections # Variable patch size support enable_patch_interpolator: bool = False # Enable dynamic patch size support def _overlay_kwargs(cfg: NaFlexVitCfg, **kwargs) -> NaFlexVitCfg: """Overlay kwargs onto config, replacing config values with provided kwargs.""" # Only update fields that exist in the config config_fields = set(cfg.__dataclass_fields__.keys()) config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} if config_kwargs: cfg = replace(cfg, **config_kwargs) return cfg def batch_patchify( x: torch.Tensor, patch_size: Tuple[int, int], pad: bool = True, ) -> Tuple[torch.Tensor, Tuple[int, int]]: """Patchify a batch of images. Args: x: Input tensor of shape [B, C, H, W]. patch_size: Patch dimensions (patch_h, patch_w). pad: Whether to pad images to be divisible by patch size. Returns: Tuple of (patches, grid_size) where patches has shape [B, N, P*P*C] and grid_size is (num_patches_h, num_patches_w). """ B, C, H, W = x.shape ph, pw = patch_size # Ensure the image is divisible by patch size if pad and (H % ph != 0 or W % pw != 0): pad_h = (ph - H % ph) % ph pad_w = (pw - W % pw) % pw x = F.pad(x, (0, pad_w, 0, pad_h)) nh, nw = H // ph, W // pw patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C) # FIXME confirm we want 'channels last' in the patch channel layout, egg ph, ph, C instead of C, ph, hw return patches, (nh, nw) def calculate_naflex_grid_sizes(_coord: torch.Tensor): # Calculate the appropriate grid size from coords max_y = _coord[:, :, 0].amax(dim=1) + 1 max_x = _coord[:, :, 1].amax(dim=1) + 1 return [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] class NaFlexRopeIterator: """Iterator for generating batched ROPE embeddings for mixed mode with multiple grid sizes.""" def __init__( self, rope_module, size_to_indices: Dict[Tuple[int, int], List[int]], unique_sizes: List[Tuple[int, int]], batch_size: int, seq_len: int, device: torch.device, dtype: torch.dtype, ): self.rope = rope_module self.size_to_indices = size_to_indices self.unique_sizes = unique_sizes self.batch_size = batch_size self.seq_len = seq_len self.dtype = dtype self.device = device self.depth = rope_module.depth self.num_heads = rope_module.num_heads self.head_dim = 2 * rope_module.dim // rope_module.num_heads self._depth_idx = 0 # Pre-compute embeddings for each unique size self._embeddings_per_size = {} for grid_size in unique_sizes: # get_embed returns all depths at once for mixed mode rope_embed = rope_module.get_embed(shape=grid_size) self._embeddings_per_size[grid_size] = rope_embed def __iter__(self): self._depth_idx = 0 return self @disable_compiler def __next__(self): if self._depth_idx >= self.depth: raise StopIteration # Create batch tensor for current depth batch_embed = torch.zeros( self.batch_size, self.num_heads, self.seq_len, self.head_dim, dtype=self.dtype, device=self.device ) # Fill in embeddings for each unique grid size for grid_size in self.unique_sizes: h, w = grid_size actual_len = h * w batch_indices = self.size_to_indices[grid_size] # Get pre-computed embeddings for this size at current depth embed = self._embeddings_per_size[grid_size][self._depth_idx] # [num_heads, H*W, dim] # Assign to batch indices for bi in batch_indices: batch_embed[bi, :, :actual_len, :] = embed[:, :actual_len, :] self._depth_idx += 1 return batch_embed def get_block_fn(cfg: NaFlexVitCfg) -> Callable: """Get appropriate block function based on configuration. Returns a partially applied block constructor with EVA-specific or conflicting parameters pre-configured if needed. """ # Check if we need EVA block features use_eva_features = ( cfg.attn_type in ('eva', 'rope') or cfg.rope_type not in ('', 'none') or # Any ROPE type requires EVA blocks cfg.swiglu_mlp ) if use_eva_features: # Determine attention type based on rope_type if not explicitly set attn_type = cfg.attn_type if attn_type == 'standard' and cfg.rope_type not in ('', 'none'): attn_type = 'rope' num_prefix_tokens = (1 if cfg.class_token else 0) + cfg.reg_tokens return partial( EvaBlock, attn_type=attn_type, swiglu_mlp=cfg.swiglu_mlp, scale_mlp=cfg.scale_mlp_norm, scale_attn_inner=cfg.scale_attn_inner_norm, qkv_fused=cfg.qkv_fused, num_prefix_tokens=num_prefix_tokens, ) else: # Standard ViT block block_fn = cfg.block_fn or Block if cfg.scale_mlp_norm or cfg.scale_attn_inner_norm: # param names differ between EVA vs non-EVA block types block_fn = partial( block_fn, scale_mlp_norm=cfg.scale_mlp_norm, scale_attn_norm=cfg.scale_attn_inner_norm ) return block_fn @register_notrace_module class NaFlexEmbeds(nn.Module): """NaFlex Embedding module for Vision Transformers. This module encapsulates the complete embedding process for Vision Transformers, supporting both standard and NaFlex (NaViT + FlexiViT) functionality: 1. Patch embedding (via Conv2d or Linear) 2. Class and register token preparation 3. Position embedding addition with interpolation support 4. Pre-normalization (if requested) 5. Dropout application NaFlex capabilities include: - Variable aspect ratio and resolution via patch coordinates - Patch type indicators for handling padding tokens in attention - Flexible position embedding interpolation for arbitrary grid sizes - Support for factorized position embeddings The patch embedding can be one of two types: - Conv2d-based (default): For standard image inputs [B, C, H, W] - Linear-based: For pre-patchified inputs [B, N, P*P*C] Args: patch_size: Size of patches for patch embedding in_chans: Number of input image channels embed_dim: Dimensionality of patch embedding proj_type: Type of embedding projection layer ('conv' or 'linear') input_norm_layer: Normalization layer applied to input (linear mode only) proj_norm_layer: Normalization layer applied after projection pos_embed: Type of position embedding ('learned', 'factorized', 'none') pos_drop_rate: Dropout rate for position embeddings class_token: Whether to include a class token reg_tokens: Number of register tokens to include bias: Whether to use bias in projection layers dynamic_img_pad: Whether to enable dynamic padding for variable resolution pos_embed_grid_size: Grid size for position embedding initialization pos_embed_interp_mode: Interpolation mode for position embedding resizing pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation default_img_size: Default image size for position embedding grid calculation """ def __init__( self, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, proj_type: Optional[str] = None, proj_bias: bool = True, class_token: bool = True, reg_tokens: int = 0, dynamic_img_pad: bool = False, default_img_size: Optional[Union[int, Tuple[int, int]]] = None, pos_embed: str = 'learned', pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14), pos_embed_interp_mode: str = 'bicubic', pos_embed_ar_preserving: bool = False, pos_embed_use_grid_sample: bool = False, input_norm_layer: Optional[Type[nn.Module]] = None, proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None, norm_layer: Optional[Type[nn.Module]] = None, pos_drop_rate: float = 0., enable_patch_interpolator: bool = False, device=None, dtype=None, ) -> None: """Initialize NaFlexEmbeds module. Args: patch_size: Size of patches for patch embedding. in_chans: Number of input image channels. embed_dim: Dimensionality of patch embedding. proj_type: Type of embedding projection layer ('conv' or 'linear'). proj_bias: Whether to use bias in projection layers. class_token: Whether to include a class token. reg_tokens: Number of register tokens to include. dynamic_img_pad: Whether to enable dynamic padding for variable resolution. default_img_size: Default image size for position embedding grid calculation. pos_embed: Type of position embedding ('learned', 'factorized', 'none'). pos_embed_grid_size: Grid size for position embedding initialization. pos_embed_interp_mode: Interpolation mode for position embedding resizing. pos_embed_ar_preserving: Whether to preserve aspect ratio during interpolation. input_norm_layer: Normalization layer applied to input (linear mode only). proj_norm_layer: Normalization layer applied after projection. norm_layer: Default normalization layer. pos_drop_rate: Dropout rate for position embeddings. enable_patch_interpolator: Enable dynamic patch size support. """ dd = {'device': device, 'dtype': dtype} super().__init__() self.has_class_token = class_token self.num_reg_tokens = reg_tokens self.pos_embed_interp_mode = pos_embed_interp_mode self.pos_embed_ar_preserving = pos_embed_ar_preserving self.pos_embed_use_grid_sample = pos_embed_use_grid_sample self.patch_size = to_2tuple(patch_size) self.in_chans = in_chans self.embed_dim = embed_dim self.dynamic_img_pad = dynamic_img_pad self.enable_patch_interpolator = enable_patch_interpolator # Calculate number of prefix tokens self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens += reg_tokens # Create class and register tokens self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if class_token else None self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None # Calculate grid size and number of patches self.default_img_size: Optional[Tuple[int, int]] = None self.pos_embed_grid_size: Optional[Tuple[int, int]] = None # Grid size used for learned pos embed init if pos_embed_grid_size is not None: # Highest priority, use provided pos_embed_grid_size self.pos_embed_grid_size = pos_embed_grid_size elif default_img_size is not None: # Fallback to calculating grid size from img_size + patch_size if img size provided. self.default_img_size = to_2tuple(default_img_size) self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)]) # Determine patch embedding type (linear or conv2d) if proj_type == 'linear': # Create linear projection for pre-patchified inputs # Input dimension is patch_size^2 * in_chans patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans assert not (input_norm_layer is True and norm_layer is None), \ "`norm_layer` must be given when input_norm_layer=True" input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None) self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias, **dd) self.flatten = False self.is_linear = True else: # Default to convolutional patch embedding for image inputs assert not input_norm_layer self.norm_input = None self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=proj_bias, **dd, ) self.flatten = True self.is_linear = False # Create patch embedding interpolator if enabled if self.enable_patch_interpolator: self.patch_interpolator = PatchEmbedInterpolator( base_patch_size=self.patch_size, in_chans=in_chans, embed_dim=embed_dim, interpolation=pos_embed_interp_mode, antialias=True, ) else: self.patch_interpolator = None # Create normalization layer after the projection assert not (proj_norm_layer is True and norm_layer is None), \ "`norm_layer` must be given when proj_norm_layer=True" proj_norm_layer = norm_layer if proj_norm_layer is True else (proj_norm_layer or None) self.norm = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity() # Create position embedding if needed - only for patches, never for prefix tokens if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None: raise ValueError( "Cannot initialize position embeddings without grid_size." "Please provide img_size or pos_embed_grid_size.") self.pos_embed: Optional[torch.Tensor] = None self.pos_embed_y: Optional[torch.Tensor] = None self.pos_embed_x: Optional[torch.Tensor] = None if not pos_embed or pos_embed == 'none': self.pos_embed_type = 'none' elif pos_embed == 'factorized': assert self.pos_embed_grid_size is not None h, w = self.pos_embed_grid_size self.pos_embed_type = 'factorized' self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim, **dd) * .02) self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim, **dd) * .02) else: assert self.pos_embed_grid_size is not None h, w = self.pos_embed_grid_size self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim, **dd) * .02) self.pos_embed_type = 'learned' # Dropout layer self.pos_drop = nn.Dropout(p=pos_drop_rate) def feature_info(self, location) -> Dict[str, Any]: """Get feature information for feature extraction. Args: location: Feature extraction location identifier Returns: Dictionary containing feature channel count and reduction factor """ return dict(num_chs=self.embed_dim, reduction=self.patch_size) def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]: """Get the feature reduction ratio (stride) of the patch embedding. Args: as_scalar: Whether to return the maximum dimension as a scalar Returns: Feature reduction ratio as scalar or tuple """ if as_scalar: return max(self.patch_size) else: return self.patch_size def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: """Calculate grid (feature) size for given image size. Takes into account dynamic padding when enabled. Args: img_size: Input image size as (height, width) Returns: Grid size as (grid_height, grid_width) """ if self.dynamic_img_pad: return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1]) else: return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] @disable_compiler def _apply_learned_naflex_pos_embed( self, x: torch.Tensor, patch_coord: torch.Tensor, ) -> None: """Apply learned position embeddings to NaFlex batch in-place. Interpolates learned 2D position embeddings for each sample in the batch based on their individual grid sizes. Args: x: Input tensor to add position embeddings to [B, N, C] patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ # Calculate grid sizes from patch coordinates naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) orig_h, orig_w = self.pos_embed.shape[1:3] pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W def _interp2d(size): """ Return a flattened positional-embedding grid at an arbitrary spatial resolution. Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into a (1, H*W, C) sequence that matches the requested size. """ if (size[0] == orig_h) and (size[1] == orig_w): pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) else: _interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size pos_embed_flat = F.interpolate( pos_embed_nchw, size=_interp_size, mode=self.pos_embed_interp_mode, align_corners=False, antialias=True, )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2) return pos_embed_flat.to(dtype=x.dtype) # Determine unique grid sizes to avoid duplicate interpolation size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): # k = h << 16 | w # FIXME can get jit compat with this size_to_indices.setdefault(k, []).append(bi) for k, batch_indices in size_to_indices.items(): # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this # Interpolate only once for this (h, w) pos_embed_flat = _interp2d(k) seq_len = min(x.shape[1], pos_embed_flat.shape[1]) x[:, :seq_len].index_add_( 0, torch.as_tensor(batch_indices, device=x.device), pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1) ) @disable_compiler def _apply_learned_naflex_pos_embed_grid_sample( self, x: torch.Tensor, patch_coord: torch.Tensor, ) -> None: """Apply learned position embeddings to NaFlex batch using grid_sample. Uses F.grid_sample for efficient interpolation of learned 2D position embeddings based on patch coordinates. Based on proposal by https://github.com/stas-sl Args: x: Input tensor to add position embeddings to [B, N, C] patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ device = x.device B, N, C = x.shape shapes = patch_coord.max(dim=1).values + 1 # (B, 2) containing [h_i, w_i] if self.pos_embed_ar_preserving: L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) L_global = L_i.amax() grid_size_y = grid_size_x = L_global scale_x = scale_y = L_global / L_i # uniform zoom (B,) else: grid_size_y, grid_size_x = shapes.amax(dim=0) # (2,) scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,) scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,) theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32) theta[:, 0, 0] = scale_x theta[:, 1, 1] = scale_y theta[:, 0, 2] = scale_x - 1 # translate x theta[:, 1, 2] = scale_y - 1 # translate y grid = F.affine_grid(theta, (B, C, grid_size_y, grid_size_x), align_corners=False) pos_embed = F.grid_sample( self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(), grid, mode=self.pos_embed_interp_mode, align_corners=False, padding_mode='border', ).to(dtype=x.dtype) # (B, C, H_out, W_out) bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1) x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] # NOTE leave as '+=' def _apply_learned_pos_embed( self, x: torch.Tensor, grid_size: List[int], ) -> None: """Apply learned position embeddings to standard 2D batch in-place. Interpolates learned 2D position embeddings to match the specified grid size. Args: x: Input tensor to add position embeddings to [B, H*W, C] grid_size: Target grid size as [height, width] """ orig_h, orig_w = self.pos_embed.shape[1:3] if grid_size[0] == orig_h and grid_size[1] == orig_w: # No resize needed, just flatten pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1) else: # Resize if needed - directly using F.interpolate if self.pos_embed_ar_preserving: L = max(grid_size) _interp_size = L, L else: _interp_size = grid_size pos_embed_flat = F.interpolate( self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W size=_interp_size, mode=self.pos_embed_interp_mode, align_corners=False, antialias=True, )[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2) pos_embed_flat = pos_embed_flat.to(dtype=x.dtype) x.add_(pos_embed_flat) @disable_compiler def _apply_factorized_naflex_pos_embed( self, x: torch.Tensor, patch_coord: torch.Tensor, ) -> None: """Apply factorized position embeddings to NaFlex batch in-place. Uses separate Y and X position embedding tables that are interpolated and combined for each sample's grid size. Args: x: Input tensor to add position embeddings to [B, N, C] patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ # Calculate grid sizes from patch coordinates naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) assert len(naflex_grid_sizes) == x.size(0) # one (H,W) per sample # Handle each batch element separately with its own grid size orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] # bucket samples that share the same (H, W) so we build each grid once size_to_indices: Dict[Tuple[int, int], List[int]] = {} for bi, k in enumerate(naflex_grid_sizes): size_to_indices.setdefault(k, []).append(bi) def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor: """ Resample a 1-D positional-embedding table to specified length and return it in (1, L, C) layout, dtype matching x. """ if new_length == orig_length: return table.to(dtype=x.dtype) return F.interpolate( table.permute(0, 2, 1).float(), # (1,C,L) → (1,C,L_out) size=new_length, mode='linear', align_corners=False, ).permute(0, 2, 1).to(dtype=x.dtype) # → (1,L_out,C) for k, batch_indices in size_to_indices.items(): target_h, target_w = k if self.pos_embed_ar_preserving: len_y = len_x = max(target_h, target_w) else: len_y, len_x = target_h, target_w pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C) pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C) # Broadcast, add and flatten to sequence layout (row major) pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1,H,W,C) pos = pos.flatten(1, 2) seq_len = min(x.shape[1], pos.shape[1]) x[:, :seq_len].index_add_( 0, torch.as_tensor(batch_indices, device=x.device), pos[:, :seq_len].expand(len(batch_indices), -1, -1) ) @disable_compiler def _apply_factorized_naflex_pos_embed_grid_sample( self, x: torch.Tensor, patch_coord: torch.Tensor, ) -> None: """Apply factorized position embeddings to NaFlex batch using grid_sample. Uses F.grid_sample for efficient interpolation of separate Y and X position embedding tables based on patch coordinates. Based on proposal by https://github.com/stas-sl Args: x: Input tensor to add position embeddings to [B, N, C] patch_coord: Patch coordinates [B, N, 2] with (y, x) values """ device = x.device B, _, C = x.shape shapes = patch_coord.amax(dim=1) + 1 if self.pos_embed_ar_preserving: # Aspect ratio preserving mode: use square grid with uniform scaling L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i) L_global = L_i.amax() grid_size_y = grid_size_x = L_global scale_x = scale_y = L_global / L_i # uniform zoom (B,) else: # Standard mode: different scaling for x and y grid_size_y, grid_size_x = shapes.amax(0) scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,) scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,) def _interp1d(table: torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor) -> torch.Tensor: pe = table.permute(0, 2, 1).unsqueeze(2).expand(B, -1, -1, -1).float() # (1, L, C) -> (B, C, 1, L) theta = torch.zeros(B, 2, 3, device=x.device) theta[:, 0, 0] = scale theta[:, 0, 2] = scale - 1 theta[:, 1, 1] = 1 grid = F.affine_grid(theta, (B, C, 1, out_length), align_corners=False) pe = F.grid_sample(pe, grid, mode='bilinear', align_corners=False, padding_mode='border') return pe.to(x.dtype) # Interpolate along each axis pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x) pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y) bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1) x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]] def _apply_factorized_pos_embed( self, x: torch.Tensor, grid_size: List[int], ) -> None: """Apply factorized position embeddings to standard 2D batch in-place. Uses separate Y and X position embedding tables that are interpolated and combined for the specified grid size. Args: x: Input tensor to add position embeddings to [B, H*W, C] grid_size: Target grid size as [height, width] """ orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1] target_h, target_w = grid_size if self.pos_embed_ar_preserving: len_y = len_x = max(target_h, target_w) else: len_y, len_x = target_h, target_w def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor: if new_length == orig_length: return table.to(dtype=x.dtype) return F.interpolate( table.permute(0, 2, 1).float(), # (1,L,C) -> (1,C,L) size=new_length, mode='linear', align_corners=False, ).permute(0, 2, 1).to(dtype=x.dtype) # (1,L,C) # Interpolate embeddings pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C) pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C) # Broadcast, add and flatten to sequence layout (row major) pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1, H, W, C) pos_embed_flat = pos_embed.flatten(1, 2) # (1, H*W, C) x.add_(pos_embed_flat) def forward( self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]: """Forward pass for patch embedding with position encoding. Args: x: Input tensor. Supported formats: - [B, C, H, W] for conv mode - [B, N, P*P*C] for pre-patchified linear mode (normal) - [B, N, Ph, Pw, C] for pre-patchified linear mode (variable patch size) patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode. patch_valid: Optional validity mask for patches [B, N] for NaFlex mode. Returns: Tuple of (embedded_tensor, grid_size) where: - embedded_tensor: [B, num_prefix_tokens + N, embed_dim] - grid_size: (H, W) tuple for standard mode, None for NaFlex mode """ grid_size: Optional[Tuple[int, int]] = None B = x.shape[0] if self.is_linear: # Linear embedding path, works with NaFlex mode or standard 2D mode if patch_coord is None: # Standard 2D (B, C, H, W) mode _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4') x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad) else: # Pre-patchified NaFlex mode # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C] _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.') # Handle variable patch size projection if self.enable_patch_interpolator and x.ndim == 5: _assert(self.norm_input is None, 'input norm not supported with patch resizing') # Apply projection with interpolation x = self.patch_interpolator( x, self.proj.weight, self.proj.bias, patch_size=tuple(x.shape[2:4]), # patch size from [B, N, Ph, Pw, C] shape is_linear=True, ) else: # Standard projection x = x.flatten(2) # ensure [B, N, P*P*C], flatten Ph*Pw*C if separate if self.norm_input is not None: x = self.norm_input(x) x = self.proj(x) else: _assert(x.ndim == 4, 'Convolutional input must be 4D') if self.dynamic_img_pad: H, W = x.shape[-2:] pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] x = F.pad(x, (0, pad_w, 0, pad_h)) x = self.proj(x) grid_size = x.shape[-2:] if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC # Apply normalization after flattening x = self.norm(x) if self.pos_embed_type == 'learned': if grid_size is not None: # Standard 2D mode self._apply_learned_pos_embed(x, grid_size=grid_size) else: # NaFlex mode if self.pos_embed_use_grid_sample: self._apply_learned_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord) else: self._apply_learned_naflex_pos_embed(x, patch_coord=patch_coord) elif self.pos_embed_type == 'factorized': if grid_size is not None: # Standard 2D mode self._apply_factorized_pos_embed(x, grid_size=grid_size) else: # NaFlex mode if self.pos_embed_use_grid_sample: self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord) else: self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord) # Prepare and add class and register tokens to_cat = [] if self.cls_token is not None: to_cat.append(self.cls_token.expand(B, -1, -1)) if self.reg_token is not None: to_cat.append(self.reg_token.expand(B, -1, -1)) # Add tokens to the beginning if to_cat: x = torch.cat(to_cat + [x], dim=1) # Apply dropout x = self.pos_drop(x) return x, grid_size @register_notrace_function def create_attention_mask( patch_valid: torch.Tensor, num_prefix_tokens: int = 0, symmetric: bool = True, q_len: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> Optional[torch.Tensor]: """Creates an attention mask from patch validity information. Supports two modes controlled by `symmetric`: 1. `symmetric=True` (default): Creates a symmetric mask of shape [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if both token i and token j are valid. Suitable for standard self-attention. 2. `symmetric=False`: Creates a potentially non-square mask of shape [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if the key/value token k is valid. Query token validity is not checked in the mask itself. Useful for cross-attention or specific self-attention implementations `q_len` can be specified. Used for NaFlex mode to handle variable token counts and padding tokens. Args: patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding. num_prefix_tokens: Number of prefix tokens (class token, register tokens) to prepend, which are always considered valid. symmetric: If True, create a symmetric mask. If False, create an expanded mask based only on key/value validity. q_len: Query sequence length override. Only used when `symmetric` is False. Defaults to the key/value sequence length (`kv_len`) if None. dtype: Dtype of the output attention mask (e.g., torch.float32). Returns: Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked). Shape is [B, 1, seq_len, seq_len] if symmetric=True, or [B, 1, q_len, kv_len] if symmetric=False. """ if patch_valid is None: return None patch_valid = patch_valid.bool() # Ensure boolean type B, N = patch_valid.shape kv_len = N # Initial key/value length is the number of patches # Prepend prefix tokens if any if num_prefix_tokens > 0: # Create prefix validity tensor on the same device/dtype base as patch_valid prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool) # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N] patch_valid = torch.cat([prefix_valid, patch_valid], dim=1) kv_len += num_prefix_tokens # Update total key/value sequence length if symmetric: # Symmetric mask is True where BOTH query and key are valid mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1) mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len] else: # Expanded mask q_len = q_len or kv_len mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len) # Create the float mask and apply masking using additive mask convention mask_float = torch.zeros_like(mask_bool, dtype=dtype) # Fill with negative infinity where mask_bool is False (masked positions) mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min) return mask_float @register_notrace_function def global_pool_naflex( x: torch.Tensor, patch_valid: Optional[torch.Tensor] = None, pool_type: str = 'token', num_prefix_tokens: int = 1, reduce_include_prefix: bool = False, ) -> torch.Tensor: """Global pooling with NaFlex support for masked tokens. Applies global pooling while respecting patch validity masks to exclude padding tokens from pooling operations. Args: x: Input tensor with shape [B, N, C] patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens] pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max') num_prefix_tokens: Number of prefix tokens (class/register) reduce_include_prefix: Whether to include prefix tokens in pooling reduction Returns: Pooled tensor with shape [B, C] """ if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'): # Fall back to standard pooling x = global_pool_nlc( x, pool_type=pool_type, num_prefix_tokens=num_prefix_tokens, reduce_include_prefix=reduce_include_prefix, ) return x # For NaFlex mode, we need to apply masked pooling to exclude padding tokens if num_prefix_tokens > 0: if reduce_include_prefix: # Include prefix tokens in pooling - they are always considered valid # patch_valid only covers patch tokens, so create combined validity mask prefix_valid = patch_valid.new_ones(x.shape[0], num_prefix_tokens) patch_valid = torch.cat([prefix_valid, patch_valid], dim=1) else: # Exclude prefix tokens from pooling (default behavior) x = x[:, num_prefix_tokens:] patch_valid_float = patch_valid.to(x.dtype) if pool_type == 'avg': # Compute masked average pooling, sum valid tokens and divide by count of valid tokens masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) pooled = masked_sums / valid_counts return pooled elif pool_type == 'avgmax': # For avgmax, compute masked average and masked max masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1) valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1) masked_avg = masked_sums / valid_counts # For max pooling we set masked positions to large negative value masked_x = x.clone() masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min masked_max = masked_x.amax(dim=1) # Combine average and max return 0.5 * (masked_avg + masked_max) elif pool_type == 'max': # For max pooling we set masked positions to large negative value masked_x = x.clone() masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min return masked_x.amax(dim=1) else: assert False class NaFlexVit(nn.Module): """NaFlexVit: Vision Transformer with NaFlex support for flexible input handling. A flexible implementation of Vision Transformer that supports: - Standard image classification with various pooling strategies - NaFlex functionality for variable aspect ratios and resolutions - Linear patch embedding for pre-patchified inputs - Multiple position embedding strategies (learned, factorized, rope) - Comprehensive attention masking for efficient batch processing - Encapsulated embedding and position encoding in FlexEmbeds module - Compatible with standard ViT checkpoints through checkpoint filtering """ def __init__( self, cfg: Optional[NaFlexVitCfg] = None, in_chans: int = 3, num_classes: int = 1000, img_size: Optional[Union[int, Tuple[int, int]]] = None, device=None, dtype=None, **kwargs, ) -> None: """Initialize NaFlexVit model. Args: cfg: Model configuration. If None, uses default NaFlexVitCfg. in_chans: Number of input image channels. num_classes: Number of classification classes. img_size: Input image size (for backwards compatibility with classic vit). **kwargs: Additional config parameters to override cfg values. """ super().__init__() dd = {'device': device, 'dtype': dtype} # Initialize config cfg = cfg or NaFlexVitCfg() if kwargs: cfg = _overlay_kwargs(cfg, **kwargs) # Validate configuration assert cfg.global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') assert cfg.class_token or cfg.global_pool != 'token' assert cfg.pos_embed in ('', 'none', 'learned', 'factorized') # Resolve layer implementations norm_layer = get_norm_layer(cfg.norm_layer) or LayerNorm embed_norm_layer = get_norm_layer(cfg.embed_norm_layer) act_layer = get_act_layer(cfg.act_layer) or nn.GELU block_fn = get_block_fn(cfg) mlp_layer = cfg.mlp_layer or Mlp # TODO: Support configurable mlp_layer via string lookup # Store instance variables self.num_classes = num_classes self.global_pool = cfg.global_pool self.num_features = self.head_hidden_size = self.embed_dim = cfg.embed_dim # for consistency with other models self.num_prefix_tokens = 1 if cfg.class_token else 0 self.num_prefix_tokens += cfg.reg_tokens self.num_reg_tokens = cfg.reg_tokens self.has_class_token = cfg.class_token self.pool_include_prefix = cfg.pool_include_prefix self.grad_checkpointing = False # Initialize embedding module (includes patch, position embedding, and class/reg tokens) # FlexEmbeds is always used - handles both linear and conv embedding self.embeds = NaFlexEmbeds( patch_size=cfg.patch_size, in_chans=in_chans, embed_dim=cfg.embed_dim, proj_type=cfg.embed_proj_type, proj_bias=not cfg.pre_norm, # disable bias if pre-norm is used (e.g. CLIP) class_token=cfg.class_token, reg_tokens=cfg.reg_tokens, default_img_size=img_size, dynamic_img_pad=cfg.dynamic_img_pad, pos_embed=cfg.pos_embed, pos_embed_grid_size=cfg.pos_embed_grid_size, pos_embed_interp_mode=cfg.pos_embed_interp_mode, pos_embed_ar_preserving=cfg.pos_embed_ar_preserving, pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample, proj_norm_layer=embed_norm_layer, pos_drop_rate=cfg.pos_drop_rate, enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False), **dd, ) self.norm_pre = norm_layer(cfg.embed_dim, **dd) if cfg.pre_norm else nn.Identity() # ROPE position embeddings at model level self.rope: Optional[nn.Module] = None self.rope_is_mixed = False if cfg.rope_type and cfg.rope_type != 'none': from timm.layers.pos_embed_sincos import RotaryEmbeddingCat, RotaryEmbeddingMixed if cfg.rope_type == 'mixed': self.rope = RotaryEmbeddingMixed( cfg.embed_dim, depth=cfg.depth, num_heads=cfg.num_heads, temperature=cfg.rope_temperature, feat_shape=None, # Dynamic shapes for NaFlex grid_indexing=cfg.rope_grid_indexing, **dd, ) self.rope_is_mixed = True elif cfg.rope_type == 'axial': self.rope = RotaryEmbeddingCat( cfg.embed_dim // cfg.num_heads, temperature=cfg.rope_temperature, in_pixels=False, feat_shape=None, # Dynamic shapes for NaFlex ref_feat_shape=cfg.rope_ref_feat_shape, grid_offset=cfg.rope_grid_offset, grid_indexing=cfg.rope_grid_indexing, **dd, ) self.rope_is_mixed = False else: raise ValueError(f"Unknown rope_type: {cfg.rope_type}") # Patch dropout if cfg.patch_drop_rate > 0: self.patch_drop = PatchDropoutWithIndices( cfg.patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens, ) else: self.patch_drop = None # Transformer blocks dpr = calculate_drop_path_rates(cfg.drop_path_rate, cfg.depth) # stochastic depth decay rule # Create transformer blocks self.blocks = nn.Sequential(*[ block_fn( dim=cfg.embed_dim, num_heads=cfg.num_heads, mlp_ratio=cfg.mlp_ratio, qkv_bias=cfg.qkv_bias, qk_norm=cfg.qk_norm, proj_bias=cfg.proj_bias, init_values=cfg.init_values, proj_drop=cfg.proj_drop_rate, attn_drop=cfg.attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, mlp_layer=mlp_layer, **dd, ) for i in range(cfg.depth) ]) # Feature info for downstream tasks patch_reduction = self.embeds.feat_ratio(as_scalar=True) self.feature_info = [ dict(module=f'blocks.{i}', num_chs=cfg.embed_dim, reduction=patch_reduction) for i in range(cfg.depth) ] self.norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and not cfg.fc_norm else nn.Identity() # Classifier Head if cfg.global_pool == 'map': self.attn_pool = AttentionPoolLatent( self.embed_dim, num_heads=cfg.attn_pool_num_heads or cfg.num_heads, mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio, norm_layer=norm_layer, act_layer=act_layer, **dd, ) else: self.attn_pool = None # Handle fc_norm default value fc_norm = cfg.fc_norm if fc_norm is None: fc_norm = cfg.global_pool == 'avg' self.fc_norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and fc_norm else nn.Identity() self.head_drop = nn.Dropout(cfg.drop_rate) self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity() if cfg.weight_init != 'skip': self.init_weights(cfg.weight_init) if cfg.fix_init: self.fix_init_weight() def fix_init_weight(self) -> None: """Apply initialization weight fix with layer-wise scaling.""" def rescale(param: torch.Tensor, _layer_id: int) -> None: with torch.no_grad(): param.div_(math.sqrt(2.0 * _layer_id)) for layer_id, layer in enumerate(self.blocks): if hasattr(layer, 'attn'): rescale(layer.attn.proj.weight, layer_id + 1) if hasattr(layer, 'mlp'): rescale(layer.mlp.fc2.weight, layer_id + 1) if hasattr(layer, 'attn_out_proj'): rescale(layer.attn_out_proj.weight, layer_id + 1) if hasattr(layer, 'mlp_out_proj'): rescale(layer.mlp_out_proj.weight, layer_id + 1) def init_weights(self, mode: str = '') -> None: """Initialize model weights according to specified scheme. Args: mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '') """ assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. named_apply(get_init_weights_vit(mode, head_bias), self) @torch.jit.ignore() def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: # Custom loading for the new model structure from .vision_transformer import _load_weights as _orig_load_weights def _load_weights_adapter(model, checkpoint_path, prefix=''): """Adapter function to handle the different model structure""" state_dict = torch.load(checkpoint_path, map_location='cpu') if isinstance(state_dict, dict) and 'state_dict' in state_dict: state_dict = state_dict['state_dict'] # Map original keys to new structure for k in list(state_dict.keys()): if k.startswith('cls_token'): state_dict['embeds.' + k] = state_dict.pop(k) elif k.startswith('reg_token'): state_dict['embeds.' + k] = state_dict.pop(k) elif k.startswith('pos_embed'): state_dict['embeds.' + k] = state_dict.pop(k) elif k.startswith('patch_embed'): state_dict['embeds.' + k[12:]] = state_dict.pop(k) return _orig_load_weights(model, state_dict, prefix) _load_weights_adapter(self, checkpoint_path, prefix) @torch.jit.ignore def no_weight_decay(self) -> Set: """Get set of parameter names that should not have weight decay applied. Returns: Set of parameter names to skip during weight decay """ skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'} if self.rope and hasattr(self.rope, 'no_weight_decay'): skip_list.update(self.rope.no_weight_decay()) return skip_list @torch.jit.ignore def group_matcher(self, coarse: bool = False) -> Dict: """Get parameter group matcher for optimizer parameter grouping. Args: coarse: Whether to use coarse-grained grouping Returns: Dictionary mapping group names to regex patterns """ return dict( stem=r'^embeds', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] ) @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: """Enable or disable gradient checkpointing for memory efficiency. Args: enable: Whether to enable gradient checkpointing """ self.grad_checkpointing = enable if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'): self.embeds.patch_embed.set_grad_checkpointing(enable) @torch.jit.ignore def get_classifier(self) -> nn.Module: """Get the classification head module. Returns: Classification head module """ return self.head @disable_compiler def _generate_rope_naflex( self, x: torch.Tensor, patch_coord: torch.Tensor, ) -> Union[torch.Tensor, List[torch.Tensor], Any]: """Generate ROPE position embeddings for NaFlex batch with variable grid sizes. Args: x: Input tensor [B, N, C] patch_coord: Patch coordinates [B, N, 2] with (y, x) values Returns: ROPE embeddings: - Axial mode: Tensor of shape [B, 1, N, dim*2] - Mixed mode: List of tensors, each of shape [B, num_heads, N, dim], one per depth layer - Mixed mode with iterator: Iterator yielding tensors per depth """ # Calculate grid sizes for each sample naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) # Build ROPE embeddings for each unique grid size size_to_indices = {} unique_sizes = [] for bi, grid_size in enumerate(naflex_grid_sizes): if grid_size not in size_to_indices: size_to_indices[grid_size] = [] unique_sizes.append(grid_size) size_to_indices[grid_size].append(bi) B, N, C = x.shape seq_len = N - self.num_prefix_tokens if self.rope_is_mixed: # Use an iterator for Mixed mode, returns [batch_size, depth, num_heads, seq_len, dim] return NaFlexRopeIterator( self.rope, size_to_indices, unique_sizes, B, seq_len, x.dtype, x.device ) # Axial mode: [batch_size, seq_len, dim*2] rope_embeds = torch.zeros(B, seq_len, self.rope.dim * 2, dtype=x.dtype, device=x.device) if hasattr(self.rope, 'get_batch_embeds'): # Batch mode - generate unique embeds from one grid and then assign unique_embeds = self.rope.get_batch_embeds(unique_sizes) for grid_size, embed, batch_indices in zip(unique_sizes, unique_embeds, size_to_indices.values()): h, w = grid_size actual_len = h * w for bi in batch_indices: rope_embeds[bi, :actual_len] = embed[:actual_len] else: # Generate each unique size separately and assign for grid_size, bi in size_to_indices.items(): rope_embed = self.rope.get_embed(shape=grid_size) h, w = grid_size actual_len = h * w rope_embeds[bi, :actual_len] = rope_embed[:actual_len] rope_embeds = rope_embeds.unsqueeze(1) return rope_embeds def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: """Reset the classification head with new number of classes and pooling. Args: num_classes: Number of classes for new classification head global_pool: Optional new global pooling type """ self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') if global_pool == 'map' and self.attn_pool is None: assert False, "Cannot currently add attention pooling in reset_classifier()." elif global_pool != 'map' and self.attn_pool is not None: self.attn_pool = None # remove attention pooling self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def _forward_embeds( self, x, patch_coord, patch_valid, attn_mask, ) -> Dict[str, torch.Tensor]: """ Forward pass through patch / abs pos / rope pos embeds and patch dropout """ naflex_mode = patch_coord is not None # patch embed, abs pos embed, returns global grid size as calculated from 'standard' NCHW batches x, grid_size = self.embeds( x, patch_coord=patch_coord, patch_valid=patch_valid, ) # Generate ROPE embeddings at model level rope_embeds = None if self.rope is not None: if patch_coord is not None: # NaFlex mode - variable grid sizes rope_embeds = self._generate_rope_naflex(x, patch_coord) elif grid_size is not None: # Standard mode - fixed grid size rope_embeds = self.rope.get_embed(shape=grid_size) else: assert False, 'Expected one of patch_coord or grid_size to be valid' # Apply patch dropout with coordinated updates keep_indices: Optional[torch.Tensor] = None if self.training and self.patch_drop is not None: x, keep_indices = self.patch_drop(x) # keep_indices excludes prefix tokens, can use directly on patch_valid & rope embeds if patch_valid is not None: patch_valid = patch_valid.gather(1, keep_indices) if rope_embeds is not None and not self.rope_is_mixed: # Update ROPE embeddings to match dropped tokens (only for axial mode) # Batch dim already present in NaFlex mode, but will be added in standard mode. rope_embeds = apply_keep_indices_nlc(x, rope_embeds, keep_indices, pos_embed_has_batch=naflex_mode) if not naflex_mode: # B, N, dim -> B, 1, N, dim. Need head dim added for standard mode, already added in NaFlex. rope_embeds = rope_embeds.unsqueeze(1) # Create attention mask from patch_valid after patch dropout applied if attn_mask is None: attn_mask = create_attention_mask( patch_valid, num_prefix_tokens=self.num_prefix_tokens, dtype=x.dtype ) x = self.norm_pre(x) return { 'patches': x, 'patch_valid': patch_valid, 'rope_embeds': rope_embeds, 'attn_mask': attn_mask, 'keep_indices': keep_indices, } def forward_intermediates( self, x: Union[torch.Tensor, Dict[str, torch.Tensor]], indices: Optional[Union[int, List[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, output_dict: bool = False, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]: """ Forward features that returns intermediates. Args: x: Input image tensor indices: Take last n blocks if int, all if None, select matching indices if sequence return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs intermediates_only: Only return intermediate features output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex attn_mask: Optional attention mask for masked attention Returns: A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix') """ # FIXME unfinished / untested assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' reshape = output_fmt == 'NCHW' intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) if isinstance(x, Dict): # Handle dictionary input from NaFlex collator patch_coord = x['patch_coord'] patch_valid = x['patch_valid'] patches = x['patches'] assert False, 'WIP, patch mode needs more work' else: patches = x height, width = x.shape[-2:] H, W = self.embeds.dynamic_feat_size((height, width)) # Forward pass through patch and abs position embedding embeds = self._forward_embeds( patches, patch_coord=patch_coord, patch_valid=patch_valid, attn_mask=attn_mask, ) x = embeds['patches'] rope_embeds = embeds.get('rope_embeds', None) keep_indices = embeds.get('keep_indices', None) attn_mask = embeds.get('attn_mask', None) # Forward pass through blocks if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript blocks = self.blocks else: blocks = self.blocks[:max_index + 1] do_checkpointing = self.grad_checkpointing and not torch.jit.is_scripting() if self.rope_is_mixed and rope_embeds is not None: # Mixed mode with per-layer embeddings (list or iterator) for i, (blk, rope_embed) in enumerate(zip(self.blocks, rope_embeds)): # Apply patch dropout to rope_embed if needed if self.training and self.patch_drop is not None and keep_indices is not None: # Apply patch dropout to rope_embed if needed (batch dim already present in naflex mode) rope_embed = apply_keep_indices_nlc( x, rope_embed, keep_indices, pos_embed_has_batch=embeds.get('naflex_mode', False), ) if do_checkpointing: x = checkpoint(blk, x, rope=rope_embed, attn_mask=attn_mask) else: x = blk(x, rope=rope_embed, attn_mask=attn_mask) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) else: for i, blk in enumerate(blocks): # Axial ROPE mode with shared embeddings if rope_embeds is not None: if do_checkpointing: x = checkpoint(blk, x, rope=rope_embeds, attn_mask=attn_mask) else: x = blk(x, rope=rope_embeds, attn_mask=attn_mask) else: if do_checkpointing: x = checkpoint(blk, x, attn_mask=attn_mask) else: x = blk(x, attn_mask=attn_mask) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) # Process intermediates if self.num_prefix_tokens: # split prefix (e.g. class, distill) and spatial feature tokens prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] else: prefix_tokens = None if reshape: # reshape to BCHW output format intermediates = [ y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates ] # FIXME always use dict for NaFlex mode to return masks and more? # For dictionary output if output_dict: result_dict = {} # Intermediates are always included result_dict['image_intermediates'] = intermediates if prefix_tokens is not None and return_prefix_tokens: result_dict['image_intermediates_prefix'] = prefix_tokens # Only include features if not intermediates_only if not intermediates_only: x_final = self.norm(x) result_dict['image_features'] = x_final return result_dict # For non-dictionary output, maintain the original behavior if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) if intermediates_only: return intermediates x = self.norm(x) return x, intermediates def forward_features( self, patches: torch.Tensor, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ """ naflex_mode = patch_coord is not None # Pass through patch & abs position embedding module with patch coordinate/type support embeds = self._forward_embeds( patches, patch_coord=patch_coord, patch_valid=patch_valid, attn_mask=attn_mask, ) x = embeds['patches'] rope_embeds = embeds.get('rope_embeds', None) keep_indices = embeds.get('keep_indices', None) attn_mask = embeds.get('attn_mask', None) # Apply transformer blocks with masked attention and/or ROPE if provided do_checkpointing = self.grad_checkpointing and not torch.jit.is_scripting() if self.rope_is_mixed and rope_embeds is not None: # Mixed mode with per-layer embeddings (list or iterator) for i, (blk, rope_embed) in enumerate(zip(self.blocks, rope_embeds)): if self.training and self.patch_drop is not None and keep_indices is not None: # Apply patch dropout to rope_embed if needed (batch dim already present in naflex mode) rope_embed = apply_keep_indices_nlc( x, rope_embed, keep_indices, pos_embed_has_batch=naflex_mode, ) if do_checkpointing: x = checkpoint(blk, x, rope=rope_embed, attn_mask=attn_mask) else: x = blk(x, rope=rope_embed, attn_mask=attn_mask) elif rope_embeds is not None: # Axial ROPE mode with shared embeddings for blk in self.blocks: if do_checkpointing: x = checkpoint(blk, x, rope=rope_embeds, attn_mask=attn_mask) else: x = blk(x, rope=rope_embeds, attn_mask=attn_mask) else: for blk in self.blocks: if do_checkpointing: x = checkpoint(blk, x, attn_mask=attn_mask) else: x = blk(x, attn_mask=attn_mask) x = self.norm(x) if naflex_mode: return { 'patches': x, 'patch_valid': embeds.get('patch_valid', None), } return x def _pool( self, x: torch.Tensor, pool_type: Optional[str] = None, patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.attn_pool is not None: attn_mask = create_attention_mask( patch_valid, num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0, symmetric=False, q_len=1, dtype=x.dtype, ) if not self.pool_include_prefix: x = x[:, self.num_prefix_tokens:] x = self.attn_pool(x, attn_mask=attn_mask) return x pool_type = self.global_pool if pool_type is None else pool_type x = global_pool_naflex( x, patch_valid, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens, reduce_include_prefix=self.pool_include_prefix, ) return x def forward_head( self, patches: torch.Tensor, pre_logits: bool = False, patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: x = self._pool(patches, patch_valid=patch_valid) x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) def forward( self, x: Union[torch.Tensor, Dict[str, torch.Tensor]], patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with optional NaFlex support. Args: x: Input tensor. Supported formats: - [B, C, H, W] standard image input - [B, N, P*P*C] pre-patchified tensor (flattened patches) - [B, N, Ph, Pw, C] pre-patchified tensor (variable patch size) - Dict from NaFlex collator patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode. patch_valid: Optional patch validity indicators for NaFlex. attn_mask: Optional attn mask to override defaults generated from patch_valid Returns: Model output tensor. """ input_is_dict = isinstance(x, Dict) naflex_mode = input_is_dict or patch_coord is not None if naflex_mode: if input_is_dict: # Handle dictionary input from NaFlex collator, dict inputs take priority over args patches = x['patches'] patch_valid = x.get('patch_valid', patch_valid) patch_coord = x.get('patch_coord', patch_coord) attn_mask = x.get('attn_mask', attn_mask) else: patches = x _assert(patch_coord is not None, "patch_coord is required in naflex mode") _assert(patch_valid is not None, "patch_valid is required in naflex mode") features = self.forward_features( patches=patches, patch_valid=patch_valid, patch_coord=patch_coord, attn_mask=attn_mask, ) # Pass patches & patch_valid to forward_head for masked pooling x = self.forward_head(**features) else: x = self.forward_features(x) x = self.forward_head(x) return x def _debug_dump_patches(x): # DEBUG, reconstruct patches & save patch_coord = x['patch_coord'] patch_valid = x['patch_valid'] patches = x['patches'] for i in range(len(patches)): patch = patches[i][patch_valid[i]] h = (patch_coord[i, :, 0].max() + 1).item() w = (patch_coord[i, :, 1].max() + 1).item() patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3) patch = patch.reshape(3, h*16, w*16) from torchvision.utils import save_image save_image(patch, f'patch_{i}.jpg', normalize=True) def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: """Function imported from vision_transformer.py to maintain compatibility""" from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm if 'jax' in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif 'moco' in mode: return init_weights_vit_moco else: return init_weights_vit_timm def checkpoint_filter_fn(state_dict: Dict[str, Any], model: NaFlexVit) -> Dict[str, Any]: """Handle state dict conversion from original ViT to the new version with combined embedding.""" # Handle CombinedEmbed module pattern out_dict = {} for k, v in state_dict.items(): # Convert tokens and embeddings to combined_embed structure if k == 'pos_embed': # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C) if hasattr(model.embeds, 'pos_embed') and v.ndim == 3: num_cls_token = 0 num_reg_token = 0 if 'reg_token' in state_dict: num_reg_token = state_dict['reg_token'].shape[1] if 'cls_token' in state_dict: num_cls_token = state_dict['cls_token'].shape[1] num_prefix_tokens = num_cls_token + num_reg_token # Original format is (1, N, C), need to reshape to (1, H, W, C) num_patches = v.shape[1] num_patches_no_prefix = num_patches - num_prefix_tokens grid_size_no_prefix = math.sqrt(num_patches_no_prefix) grid_size = math.sqrt(num_patches) if (grid_size_no_prefix != grid_size and (grid_size_no_prefix.is_integer() and not grid_size.is_integer()) ): # make a decision, did the pos_embed of the original include the prefix tokens? num_patches = num_patches_no_prefix cls_token_emb = v[:, 0:num_cls_token] if cls_token_emb.numel(): state_dict['cls_token'] += cls_token_emb reg_token_emb = v[:, num_cls_token:num_reg_token] if reg_token_emb.numel(): state_dict['reg_token'] += reg_token_emb v = v[:, num_prefix_tokens:] grid_size = grid_size_no_prefix grid_size = int(grid_size) # Check if it's a perfect square for a standard grid if grid_size * grid_size == num_patches: # Reshape from (1, N, C) to (1, H, W, C) v = v.reshape(1, grid_size, grid_size, v.shape[2]) else: # Not a square grid, we need to get the actual dimensions if hasattr(model.embeds.patch_embed, 'grid_size'): h, w = model.embeds.patch_embed.grid_size if h * w == num_patches: # We have the right dimensions v = v.reshape(1, h, w, v.shape[2]) else: # Dimensions don't match, use interpolation _logger.warning( f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. " f"Using default initialization and will resize in forward pass." ) # Keep v as is, the forward pass will handle resizing out_dict['embeds.pos_embed'] = v elif k == 'cls_token': out_dict['embeds.cls_token'] = v elif k == 'reg_token': out_dict['embeds.reg_token'] = v # Convert patch_embed.X to embeds.patch_embed.X elif k.startswith('patch_embed.'): suffix = k[12:] if suffix == 'proj.weight': v = v.permute(0, 2, 3, 1).flatten(1) new_key = 'embeds.' + suffix out_dict[new_key] = v else: out_dict[k] = v return out_dict def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: return { 'url': url, 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None, 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'embeds.proj', 'classifier': 'head', 'license': 'apache-2.0', **kwargs, } default_cfgs = generate_default_cfgs({ 'naflexvit_base_patch16_gap.e300_s576_in1k': _cfg( hf_hub_id='timm/', ), 'naflexvit_base_patch16_par_gap.e300_s576_in1k': _cfg( hf_hub_id='timm/', ), 'naflexvit_base_patch16_parfac_gap.e300_s576_in1k': _cfg( hf_hub_id='timm/', ), 'naflexvit_base_patch16_map.untrained': _cfg(), 'naflexvit_so150m2_patch16_reg1_gap.untrained': _cfg(), 'naflexvit_so150m2_patch16_reg1_map.untrained': _cfg(), # SigLIP-2 NaFlex vit encoder weights 'naflexvit_base_patch16_siglip.v2_webli': _cfg( hf_hub_id='timm/', num_classes=0), 'naflexvit_so400m_patch16_siglip.v2_webli': _cfg( hf_hub_id='timm/', num_classes=0), }) def _create_naflexvit(variant: str, pretrained: bool = False, **kwargs) -> NaFlexVit: out_indices = kwargs.pop('out_indices', 3) cfg = kwargs.pop('cfg', NaFlexVitCfg()) cfg_field_names = {f.name for f in fields(NaFlexVitCfg)} # pop in-place so the original kwargs is emptied of cfg-specific keys cfg_updates = {k: kwargs.pop(k) for k in list(kwargs) if k in cfg_field_names} if cfg_updates: cfg = _overlay_kwargs(cfg, **cfg_updates) model = build_model_with_cfg( NaFlexVit, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, cfg=cfg, feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) return model def _create_naflexvit_from_classic( variant: str, pretrained: bool = False, **kwargs, ) -> NaFlexVit: """Create FlexVit model from classic VisionTransformer configuration. This function handles the parameter mapping and configuration logic needed to create FlexVit models that are compatible with classic VisionTransformer configurations and pretrained weights. Args: variant: Model variant name pretrained: Whether to load pretrained weights **kwargs: Classic VisionTransformer parameters Returns: FlexVit model instance """ # Remove VisionTransformer-specific parameters that don't apply to FlexVit kwargs.pop('no_embed_class', None) kwargs.pop('dynamic_img_size', None) # Handle global pooling and fc_norm defaults that differ between ViT and FlexVit gp = kwargs.pop('global_pool', 'token') # Original ViTs default to cls token pooling fc_norm = kwargs.pop('fc_norm', None) # Original ViTs used fc_norm when not set and avg pooling used if fc_norm is None and gp == 'avg': fc_norm = True # Set FlexVit-specific defaults that differ from VisionTransformer flex_kwargs = { 'pos_embed_grid_size': None, # rely on img_size (// patch_size) that will be passed through 'class_token': kwargs.get('class_token', True), 'global_pool': gp, 'fc_norm': fc_norm, 'scale_mlp_norm': kwargs.pop('scale_mlp_norm', False), 'scale_attn_inner_norm': kwargs.pop('scale_attn_norm', False), **kwargs # User overrides take precedence } return _create_naflexvit(variant, pretrained, **flex_kwargs) def _create_naflexvit_from_eva( variant: str, pretrained: bool = False, **kwargs, ) -> NaFlexVit: """Create NaFlexVit model from EVA configuration. This function handles the parameter mapping and configuration logic needed to create NaFlexVit models that are compatible with EVA configurations and pretrained weights. Args: variant: Model variant name pretrained: Whether to load pretrained weights **kwargs: EVA model parameters Returns: NaFlexVit model instance """ # Handle EVA's unique parameters & block args kwargs.pop('no_embed_class', None) # EVA specific, not used in NaFlexVit (always no-embed) # Map EVA's rope parameters use_rot_pos_emb = kwargs.pop('use_rot_pos_emb', False) rope_mixed_mode = kwargs.pop('rope_mixed_mode', False) rope_temperature = kwargs.pop('rope_temperature', 10000.) rope_grid_offset = kwargs.pop('rope_grid_offset', 0.) rope_grid_indexing = kwargs.pop('rope_grid_indexing', 'ij') if use_rot_pos_emb: rope_type = 'mixed' if rope_mixed_mode else 'axial' else: rope_type = 'none' # Handle norm/pool resolution logic to mirror EVA gp = kwargs.pop('global_pool', 'avg') use_pre_transformer_norm = kwargs.pop('use_pre_transformer_norm', False) use_post_transformer_norm = kwargs.pop('use_post_transformer_norm', True) use_fc_norm = kwargs.pop('use_fc_norm', None) if use_fc_norm is None: use_fc_norm = gp == 'avg' # default on if avg pool used # Set NaFlexVit-specific parameters naflex_kwargs = { 'pos_embed_grid_size': None, # rely on img_size (// patch_size) 'class_token': kwargs.get('class_token', True), 'reg_tokens': kwargs.pop('num_reg_tokens', kwargs.get('reg_tokens', 0)), 'global_pool': gp, 'pre_norm': use_pre_transformer_norm, 'final_norm': use_post_transformer_norm, 'fc_norm': use_fc_norm, 'pos_embed': 'learned' if kwargs.pop('use_abs_pos_emb', True) else 'none', 'rope_type': rope_type, 'rope_temperature': rope_temperature, 'rope_grid_offset': rope_grid_offset, 'rope_grid_indexing': rope_grid_indexing, 'rope_ref_feat_shape': kwargs.get('ref_feat_shape', None), 'attn_type': kwargs.pop('attn_type', 'eva'), 'swiglu_mlp': kwargs.pop('swiglu_mlp', False), 'qkv_fused': kwargs.pop('qkv_fused', True), 'scale_mlp_norm': kwargs.pop('scale_mlp', False), 'scale_attn_inner_norm': kwargs.pop('scale_attn_inner', False), **kwargs # Pass remaining kwargs through } return _create_naflexvit(variant, pretrained, **naflex_kwargs) @register_model def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-Base with NaFlex functionality and global average pooling. """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, global_pool='avg', reg_tokens=4, fc_norm=True, ) model = _create_naflexvit('naflexvit_base_patch16_gap', pretrained=pretrained, cfg=cfg, **kwargs) return model @register_model def naflexvit_base_patch16_par_gap(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-Base with NaFlex functionality, aspect preserving pos embed, global average pooling. """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, pos_embed_ar_preserving=True, global_pool='avg', reg_tokens=4, fc_norm=True, ) model = _create_naflexvit('naflexvit_base_patch16_par_gap', pretrained=pretrained, cfg=cfg, **kwargs) return model @register_model def naflexvit_base_patch16_parfac_gap(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-Base with NaFlex functionality, aspect preserving & factorized pos embed, global average pooling. """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, pos_embed_ar_preserving=True, pos_embed='factorized', global_pool='avg', reg_tokens=4, fc_norm=True, ) model = _create_naflexvit('naflexvit_base_patch16_parfac_gap', pretrained=pretrained, cfg=cfg, **kwargs) return model @register_model def naflexvit_base_patch16_map(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-Base with NaFlex functionality and MAP attention pooling. """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, global_pool='map', reg_tokens=1, ) model = _create_naflexvit('naflexvit_base_patch16_map', pretrained=pretrained, cfg=cfg, **kwargs) return model @register_model def naflexvit_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-SO150M2 with NaFlex functionality for variable aspect ratios and resolutions. This model supports: 1. Variable aspect ratios and resolutions via patch coordinates 2. Position embedding interpolation for arbitrary grid sizes 3. Explicit patch coordinates and valid token masking """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5, qkv_bias=False, reg_tokens=1, global_pool='avg', fc_norm=True, ) model = _create_naflexvit('naflexvit_so150m2_patch16_reg1_gap', pretrained=pretrained, cfg=cfg, **kwargs) return model @register_model def naflexvit_so150m2_patch16_reg1_map(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-SO150M2 with NaFlex functionality for variable aspect ratios and resolutions. This model supports: 1. Variable aspect ratios and resolutions via patch coordinates 2. Position embedding interpolation for arbitrary grid sizes 3. Explicit patch coordinates and valid token masking """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5, qkv_bias=False, reg_tokens=1, global_pool='map', ) model = _create_naflexvit('naflexvit_so150m2_patch16_reg1_map', pretrained=pretrained, cfg=cfg, **kwargs) return model @register_model def naflexvit_base_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-Base with NaFlex functionality and SigLIP-style configuration. """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=768, depth=12, num_heads=12, act_layer='gelu_tanh', global_pool='map', ) model = _create_naflexvit('naflexvit_base_patch16_siglip', pretrained=pretrained, cfg=cfg, **kwargs) return model @register_model def naflexvit_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-SO400M with NaFlex functionality for variable aspect ratios and resolutions. """ cfg = NaFlexVitCfg( patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, act_layer='gelu_tanh', global_pool='map', ) model = _create_naflexvit('naflexvit_so400m_patch16_siglip', pretrained=pretrained, cfg=cfg, **kwargs) return model