""" Sin-cos, fourier, rotary position embedding modules and functions Hacked together by / Copyright 2022 Ross Wightman """ import math from typing import List, Tuple, Optional, Union import torch from torch import nn as nn from ._fx import register_notrace_function from .grid import ndgrid from .trace_utils import _assert def pixel_freq_bands( num_bands: int, max_freq: float = 224., linear_bands: bool = True, device: Optional[torch.device] = None, ): if linear_bands: bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device) else: bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device) return bands * torch.pi def freq_bands( num_bands: int, temperature: float = 10000., step: int = 2, device: Optional[torch.device] = None, ) -> torch.Tensor: exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands bands = 1. / (temperature ** exp) return bands def build_sincos2d_pos_embed( feat_shape: List[int], dim: int = 64, temperature: float = 10000., reverse_coord: bool = False, interleave_sin_cos: bool = False, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ Args: feat_shape: dim: temperature: reverse_coord: stack grid order W, H instead of H, W interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos dtype: device: Returns: """ assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' pos_dim = dim // 4 bands = freq_bands(pos_dim, temperature=temperature, step=1, device=device) if reverse_coord: feat_shape = feat_shape[::-1] # stack W, H instead of H, W grid = torch.stack(ndgrid([ torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape ])).flatten(1).transpose(0, 1) pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) # FIXME add support for unflattened spatial dim? stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) return pos_emb.to(dtype=dtype) def swap_shape_xy(seq: List[int]) -> List[int]: if len(seq) < 2: return seq return [seq[1], seq[0]] + list(seq[2:]) def build_fourier_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, num_bands: int = 64, max_res: int = 224, temperature: float = 10000., linear_bands: bool = False, include_grid: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, grid_offset: float = 0., grid_indexing: str = 'ij', device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ) -> List[torch.Tensor]: """ Args: feat_shape: Feature shape for embedding. bands: Pre-calculated frequency bands. num_bands: Number of frequency bands (determines output dim). max_res: Maximum resolution for pixel based freq. temperature: Temperature for non-pixel freq. linear_bands: Linear band spacing for pixel based freq. include_grid: Include the spatial grid in output. in_pixels: Output in pixel freq. ref_feat_shape: Reference feature shape for resize / fine-tune. grid_offset: Constant offset to add to grid for non-pixel freq. grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') dtype: Output dtype. device: Output device. 
def swap_shape_xy(seq: List[int]) -> List[int]:
    if len(seq) < 2:
        return seq
    return [seq[1], seq[0]] + list(seq[2:])


def build_fourier_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        include_grid: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        grid_offset: float = 0.,
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> List[torch.Tensor]:
    """

    Args:
        feat_shape: Feature shape for embedding.
        bands: Pre-calculated frequency bands.
        num_bands: Number of frequency bands (determines output dim).
        max_res: Maximum resolution for pixel based freq.
        temperature: Temperature for non-pixel freq.
        linear_bands: Linear band spacing for pixel based freq.
        include_grid: Include the spatial grid in output.
        in_pixels: Output in pixel freq.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        grid_offset: Constant offset to add to grid for non-pixel freq.
        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
        dtype: Output dtype.
        device: Output device.

    Returns:
        List of [sin, cos] embedding tensors, prefixed by the spatial grid if include_grid is True.
    """
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(
                num_bands,
                float(max_res),
                linear_bands=linear_bands,
                device=device,
            )
        else:
            bands = freq_bands(
                num_bands,
                temperature=temperature,
                step=1,
                device=device,
            )
    else:
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if grid_indexing == 'xy':
        feat_shape = swap_shape_xy(feat_shape)
        if ref_feat_shape is not None:
            ref_feat_shape = swap_shape_xy(ref_feat_shape)

    if in_pixels:
        t = [
            torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32)
            for s in feat_shape
        ]
    else:
        t = [
            torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + grid_offset
            for s in feat_shape
        ]

    if ref_feat_shape is not None:
        # eva's scheme for resizing rope embeddings (ref shape = pretrain)
        t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]

    grid = torch.stack(torch.meshgrid(t, indexing=grid_indexing), dim=-1)
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype=dtype)
    out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
    return out
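

# Usage sketch (editor's illustration): sin/cos tensors for a 4x6 grid with 8
# bands; outputs keep the spatial grid shape, one value per (position, band):
#
#   >>> sin, cos = build_fourier_pos_embed([4, 6], num_bands=8)
#   >>> sin.shape
#   torch.Size([4, 6, 2, 8])  # (H, W, num_spatial_dims, num_bands)

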
class FourierEmbed(nn.Module):

    def __init__(
            self,
            max_res: int = 224,
            num_bands: int = 64,
            concat_grid=True,
            keep_spatial=False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        self.register_buffer(
            'bands',
            # NOTE: pixel_freq_bands takes (num_bands, max_freq), in that order
            pixel_freq_bands(num_bands, float(max_res)).to(device=device, dtype=dtype),
            persistent=False,
        )

    def forward(self, x):
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device,
        )
        emb = torch.cat(emb, dim=-1)
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
        else:
            x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
            x = x.reshape(B, feat_shape.numel(), -1)

        return x


def rot(x):
    # x:   [ x0  x1  x2  x3  x4  x5]
    # out: [-x1  x0 -x3  x2 -x5  x4]
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
    # x:   [ x0  x1  x2  x3  x4  x5]
    # out: [-x3 -x4 -x5  x0  x1  x2]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rot_embed(
        x: torch.Tensor,
        sin_emb: torch.Tensor,
        cos_emb: torch.Tensor,
        half: bool = False,
) -> torch.Tensor:
    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
    if half:
        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
        # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
        return x * cos_emb + rope_rotate_half(x) * sin_emb
    else:
        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
        # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
        return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(
        x: List[torch.Tensor],
        sin_emb: torch.Tensor,
        cos_emb: torch.Tensor,
        half: bool = False,
) -> List[torch.Tensor]:
    if isinstance(x, torch.Tensor):
        x = [x]
    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
    if half:
        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
        # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
        return [t * cos_emb + rope_rotate_half(t) * sin_emb for t in x]
    else:
        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
        # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
        return [t * cos_emb + rot(t) * sin_emb for t in x]


def apply_rot_embed_cat(
        x: torch.Tensor,
        emb: torch.Tensor,
        half: bool = False,
) -> torch.Tensor:
    sin_emb, cos_emb = emb.chunk(2, -1)
    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
    if half:
        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
        # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
        return x * cos_emb + rope_rotate_half(x) * sin_emb
    else:
        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
        # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
        return x * cos_emb + rot(x) * sin_emb
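

# Worked sketch (editor's illustration) of the two rotation layouts on a 4-dim
# vector x = [x0, x1, x2, x3] with per-pair angles a, b:
#   interleaved (half=False): sin = [sin_a, sin_a, sin_b, sin_b]
#     out = [x0*cos_a - x1*sin_a, x1*cos_a + x0*sin_a,
#            x2*cos_b - x3*sin_b, x3*cos_b + x2*sin_b]
#   half (half=True): sin = [sin_a, sin_b, sin_a, sin_b]
#     out = [x0*cos_a - x2*sin_a, x1*cos_b - x3*sin_b,
#            x2*cos_a + x0*sin_a, x3*cos_b + x1*sin_b]
# Both apply the same 2D rotations, differing only in how elements are paired.

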
def apply_keep_indices_nlc(
        x: torch.Tensor,
        pos_embed: torch.Tensor,
        keep_indices: torch.Tensor,
        pos_embed_has_batch: bool = False,
) -> torch.Tensor:
    """ Apply keep indices to different ROPE shapes

    Expected pos_embed shapes:
    * [seq_len, pos_embed_dim] --> output [batch_size, seq_len, pos_embed_dim]
    * [num_heads, seq_len, pos_embed_dim] --> output [batch_size, num_heads, seq_len, pos_embed_dim]
    * [depth, num_heads, seq_len, pos_embed_dim] --> output [batch_size, depth, num_heads, seq_len, pos_embed_dim]

    And all of the above with leading batch dimension already present if `pos_embed_has_batch == True`
    """
    if pos_embed_has_batch:
        # Pos embed already includes batch dim
        _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions')  # At least [batch, seq_len, pos_embed_dim]
    else:
        # Add batch dimension and expand to batch size
        _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions')  # At least [seq_len, pos_embed_dim]
        expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
        pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)

    # Reshape keep_indices to add singleton dims
    keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1)
    keep_indices = keep_indices.view(keep_shape)

    # Expand all dims to match position embedding except the gather dim (second-last)
    keep_expand = list(pos_embed.shape)
    keep_expand[-2] = -1
    keep_indices = keep_indices.expand(keep_expand)

    return pos_embed.gather(-2, keep_indices)


def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        grid_offset: float = 0.,
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
):
    """

    Args:
        feat_shape: Spatial shape of the target tensor for embedding.
        bands: Optional pre-generated frequency bands
        dim: Output dimension of embedding tensor.
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode
        linear_bands: Linearly (instead of log) spaced bands for pixel mode
        in_pixels: Pixel vs language (inv freq) mode.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        grid_offset: Constant offset to add to grid for non-pixel freq.
        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tuple of (sin, cos) embedding tensors, each of shape (N, dim), where N is the product of feat_shape.
    """
    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        grid_offset=grid_offset,
        grid_indexing=grid_indexing,
        device=device,
        dtype=dtype,
    )
    num_spatial_dim = 1
    # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
    for x in feat_shape:
        num_spatial_dim *= x
    sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb
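

# Usage sketch (editor's illustration): ROPE sin/cos for a 14x14 grid and a
# 64-dim (per-head) embedding. dim // 4 bands per axis are doubled by
# repeat_interleave, and the two axes together fill the full dim:
#
#   >>> sin, cos = build_rotary_pos_embed([14, 14], dim=64)
#   >>> sin.shape
#   torch.Size([196, 64])  # (H * W, dim)

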
class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
    been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
            grid_offset: float = 0.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.linear_bands = linear_bands
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands.to(device=device, dtype=dtype),
                persistent=False,
            )
            self.pos_embed_sin = None
            self.pos_embed_cos = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            emb_sin, emb_cos = self._get_pos_embed_values(feat_shape, device=device, dtype=dtype)
            self.bands = None
            self.register_buffer(
                'pos_embed_sin',
                emb_sin,
                persistent=False,
            )
            self.register_buffer(
                'pos_embed_cos',
                emb_cos,
                persistent=False,
            )

    def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
        emb_sin, emb_cos = build_rotary_pos_embed(
            feat_shape=feat_shape,
            dim=self.dim,
            max_res=self.max_res,
            temperature=self.temperature,
            linear_bands=self.linear_bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
            device=device,
            dtype=dtype,
        )
        return emb_sin, emb_cos

    def update_feat_shape(self, feat_shape: List[int]):
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            # only update if feat_shape was set and different from previous value
            assert self.pos_embed_sin is not None
            assert self.pos_embed_cos is not None
            self.pos_embed_sin, self.pos_embed_cos = self._get_pos_embed_values(
                feat_shape,
                device=self.pos_embed_sin.device,
                dtype=self.pos_embed_sin.dtype,
            )
            self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None):
        if shape is not None and self.bands is not None:
            # rebuild embeddings every call, use if target shape changes
            return build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
                grid_offset=self.grid_offset,
                grid_indexing=self.grid_indexing,
            )
        elif self.pos_embed_sin is not None and self.pos_embed_cos is not None:
            return self.pos_embed_sin, self.pos_embed_cos
        else:
            assert False, "get_embed() requires pre-computed pos embeds or valid shape w/ pre-computed bands"

    def forward(self, x):
        # assuming channel-first tensor where spatial dims are >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
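

# Usage sketch (editor's illustration): with feat_shape given up front the
# sin/cos tables are cached as buffers; otherwise they are rebuilt per call
# from cached bands:
#
#   >>> rope = RotaryEmbedding(dim=64, feat_shape=[14, 14])
#   >>> sin_emb, cos_emb = rope.get_embed()  # each (196, 64)

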
class RotaryEmbeddingCat(nn.Module):
    """ Rotary position embedding w/ concatenated sin & cos

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim: int,
            max_res: int = 224,
            temperature: float = 10000,
            in_pixels: bool = True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
            grid_offset: float = 0.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.linear_bands = linear_bands
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands.to(device=device, dtype=dtype),
                persistent=False,
            )
            self.pos_embed = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            self.bands = None
            self.register_buffer(
                'pos_embed',
                self._get_pos_embed_values(feat_shape=feat_shape, device=device, dtype=dtype),
                persistent=False,
            )

    def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
        embeds = build_rotary_pos_embed(
            feat_shape=feat_shape,
            dim=self.dim,
            max_res=self.max_res,
            temperature=self.temperature,
            linear_bands=self.linear_bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
            device=device,
            dtype=dtype,
        )
        return torch.cat(embeds, -1)

    def update_feat_shape(self, feat_shape: List[int]):
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            # only update if feat_shape was set and different from previous value
            assert self.pos_embed is not None
            self.pos_embed = self._get_pos_embed_values(
                feat_shape,
                device=self.pos_embed.device,
                dtype=self.pos_embed.dtype,
            )
            self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None):
        if shape is not None and self.bands is not None:
            # rebuild embeddings from cached bands every call, use if target shape changes
            embeds = build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
                grid_offset=self.grid_offset,
                grid_indexing=self.grid_indexing,
            )
            return torch.cat(embeds, -1)
        elif self.pos_embed is not None:
            return self.pos_embed
        else:
            assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"

    def get_batch_embeds(
            self,
            shapes: List[Tuple[int, int]],
            seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Generate ROPE embeddings for multiple grid shapes efficiently.

        Computes embeddings for the maximum grid size once, then extracts and
        flattens the relevant portions for each requested shape.

        Args:
            shapes: List of (H, W) tuples representing different grid sizes
            seq_len: If provided, return a padded tensor of this length. Otherwise return a list.

        Returns:
            If seq_len is provided: Padded tensor of shape (len(shapes), seq_len, 2 * dim)
            Otherwise: List of concatenated sin/cos embeddings for each shape, where each
            tensor has shape (H * W, 2 * dim)
        """
        if not shapes:
            return []

        # Check if we have pre-computed bands
        if self.bands is None:
            # If we have pre-computed pos_embed for a fixed shape, we can't do batch generation
            raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")

        # Find max dimensions across all shapes
        max_h = max(h for h, w in shapes)
        max_w = max(w for h, w in shapes)

        # Generate embeddings for max size ONCE
        sin_emb, cos_emb = build_rotary_pos_embed(
            feat_shape=(max_h, max_w),
            bands=self.bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
        )

        # sin_emb and cos_emb are each (max_h * max_w, dim), concat and reshape to 2D for slicing
        rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)

        if seq_len is not None:
            flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb)
            for i, (h, w) in enumerate(shapes):
                src_len = h * w
                flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
            return flat_embeds
        else:
            flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
            return flat_embeds_list

    def forward(self, x):
        # assuming channel-first tensor where spatial dims are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)
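

# Usage sketch (editor's illustration): the 'cat' variant returns sin and cos
# concatenated in one tensor, split back out inside apply_rot_embed_cat:
#
#   >>> rope = RotaryEmbeddingCat(dim=64, feat_shape=[14, 14])
#   >>> rope.get_embed().shape
#   torch.Size([196, 128])  # (H * W, 2 * dim)

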
def init_random_2d_freqs(
        head_dim: int,
        depth: int,
        num_heads: int,
        temperature: float = 10.0,
        rotate: bool = True,
        *,
        device=None,
        dtype=torch.float32,
) -> torch.Tensor:
    """ Vectorised 2D ROPE frequencies with random rotation for mixed mode ROPE.

    Returns:
        Tensor (2, depth, num_heads, head_dim//2)
    """
    # base magnitudes, shape: (head_dim//4,)
    mag = 1.0 / (temperature ** (torch.arange(0, head_dim, 4, device=device, dtype=dtype) / head_dim))

    # (1, 1, L) so it broadcasts over both depth and heads
    mag = mag.unsqueeze(0).unsqueeze(0)

    # random (or zero) rotation per head *and* per block
    if rotate:
        angles = torch.rand(depth, num_heads, 1, device=device, dtype=dtype) * 2 * torch.pi
    else:
        angles = torch.zeros(depth, num_heads, 1, device=device, dtype=dtype)

    # build (depth, num_heads, 2*L) == head_dim//2 on the last axis
    fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(angles + torch.pi / 2)], dim=-1)
    fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(angles + torch.pi / 2)], dim=-1)

    # (2, depth, num_heads, head_dim//2)
    return torch.stack([fx, fy], dim=0)


@torch.fx.wrap
@register_notrace_function
def get_mixed_grid(
        shape: List[int],
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if grid_indexing == 'xy':
        shape = swap_shape_xy(shape)
    x_pos, y_pos = torch.meshgrid(
        torch.arange(shape[0], device=device, dtype=torch.float32),
        torch.arange(shape[1], device=device, dtype=torch.float32),
        indexing=grid_indexing,
    )
    t_x = x_pos.to(dtype).flatten()
    t_y = y_pos.to(dtype).flatten()
    return t_x, t_y


def get_mixed_freqs(
        freqs: torch.Tensor,
        t_x: torch.Tensor,
        t_y: torch.Tensor,
) -> torch.Tensor:
    """Compute mixed (learnable) frequencies."""
    dtype = freqs.dtype
    freqs = freqs.float()
    # outer product of flattened positions and learned per-head frequencies
    freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
    freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
    combined = freqs_x + freqs_y  # shape: (depth, num_heads, N, head_dim // 2)
    sin_emb = torch.sin(combined).repeat_interleave(2, -1)  # (depth, num_heads, N, head_dim)
    cos_emb = torch.cos(combined).repeat_interleave(2, -1)  # (depth, num_heads, N, head_dim)
    rope_embeds = torch.cat([sin_emb, cos_emb], dim=-1)  # (depth, num_heads, N, 2 * head_dim)
    return rope_embeds.to(dtype)
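

# Shape sketch (editor's illustration) for the mixed-mode helpers above, with
# head_dim=64, depth=12, num_heads=3 and a 14x14 grid:
#   init_random_2d_freqs -> (2, 12, 3, 32)     # (2, depth, num_heads, head_dim // 2)
#   get_mixed_grid       -> t_x, t_y, each (196,)
#   get_mixed_freqs      -> (12, 3, 196, 128)  # (depth, num_heads, N, 2 * head_dim)

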
class RotaryEmbeddingMixed(nn.Module):
    """Rotary position embedding with depth-dependent learnable frequencies.

    This implementation supports mixed (learnable) ROPE. In mixed mode,
    each transformer block has its own set of learnable frequency parameters.

    Based on 'Rotary Position Embedding for Vision Transformer' - https://arxiv.org/abs/2403.13298
    Compatible with original at https://github.com/naver-ai/rope-vit
    """

    def __init__(
            self,
            dim: int,
            depth: int,
            num_heads: int,
            temperature: float = 10.0,
            feat_shape: Optional[List[int]] = None,
            grid_indexing: str = 'xy',
            device=None,
            dtype=None,
    ):
        """Initialize rotary embeddings.

        Args:
            dim: Embedding dimension (should be divisible by 4)
            depth: Number of transformer blocks
            num_heads: Number of attention heads
            temperature: Base for frequency computation
            feat_shape: Spatial dimensions [H, W] if known in advance
            grid_indexing: How to index grid positions ('xy' or 'ij')
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.temperature = temperature
        self.feat_shape = feat_shape
        self.grid_indexing = grid_indexing

        head_dim = dim // num_heads
        assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"

        freqs = init_random_2d_freqs(
            head_dim,
            depth,
            num_heads,
            temperature=temperature,
            rotate=True,
            device=device,
            dtype=dtype,
        )  # (2, depth, num_heads, head_dim//2)
        self.freqs = nn.Parameter(freqs)

        if feat_shape is not None:
            # cache pre-computed grid
            t_x, t_y = self._get_grid_values(feat_shape)
            self.register_buffer('t_x', t_x, persistent=False)
            self.register_buffer('t_y', t_y, persistent=False)
        else:
            self.t_x = self.t_y = None

    def _get_grid_values(self, feat_shape: Optional[List[int]]):
        t_x, t_y = get_mixed_grid(
            feat_shape,
            grid_indexing=self.grid_indexing,
            device=self.freqs.device,
        )
        return t_x, t_y

    def update_feat_shape(self, feat_shape: Optional[List[int]]):
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            assert self.t_x is not None
            assert self.t_y is not None
            t_x, t_y = self._get_grid_values(feat_shape)
            self.t_x = t_x.to(self.t_x.device, self.t_x.dtype)
            self.t_y = t_y.to(self.t_y.device, self.t_y.dtype)
            self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rotary embeddings for the given spatial shape.

        Args:
            shape: Spatial dimensions [H, W]

        Returns:
            Tensor of shape (depth, num_heads, H * W, 2 * head_dim) containing
            concatenated sin/cos embeddings
        """
        if shape is not None:
            t_x, t_y = get_mixed_grid(
                shape,
                grid_indexing=self.grid_indexing,
                device=self.freqs.device,
            )
        elif self.t_x is not None and self.t_y is not None:
            t_x, t_y = self.t_x, self.t_y
        else:
            assert False, "get_embed() requires pre-computed t_x/t_y or valid shape"

        return get_mixed_freqs(self.freqs, t_x, t_y)

    def get_batch_embeds(
            self,
            shapes: List[Tuple[int, int]],
            seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Generate ROPE embeddings for multiple grid shapes efficiently.

        Computes embeddings for the maximum grid size once, then extracts and
        flattens the relevant portions for each requested shape.

        Args:
            shapes: List of (H, W) tuples representing different grid sizes
            seq_len: If provided, return padded tensor of this length. Otherwise return list.

        Returns:
            If seq_len is provided: Padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim)
            Otherwise: List of tensors with shape (depth, num_heads, H*W, dim) for each shape
        """
        if not shapes:
            return []

        # Find max dimensions
        max_h = max(h for h, w in shapes)
        max_w = max(w for h, w in shapes)

        # Generate embeddings for max size ONCE
        t_x, t_y = get_mixed_grid(
            [max_h, max_w],
            grid_indexing=self.grid_indexing,
            device=self.freqs.device,
        )
        max_embed = get_mixed_freqs(self.freqs, t_x, t_y)  # (depth, num_heads, max_h * max_w, dim)

        # Reshape to 2D grid for easy slicing
        depth, num_heads, _, dim = max_embed.shape
        max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim)

        if seq_len is not None:
            # Return padded tensor
            B = len(shapes)
            padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype)
            for i, (h, w) in enumerate(shapes):
                # Slice and flatten
                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
                actual_len = h * w
                padded[i, :, :, :actual_len] = embed_slice
            return padded
        else:
            # Return list
            results = []
            for h, w in shapes:
                # Slice and flatten
                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
                results.append(embed_slice)
            return results

    def forward(self, x):
        # assuming channel-first tensor where spatial dims are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)

    def no_weight_decay(self):
        """Exclude frequency parameters from weight decay."""
        return {'freqs'}


@torch.fx.wrap
@register_notrace_function
def make_coords_dinov3(
        height: int,
        width: int,
        normalize_coords: str = 'separate',
        grid_indexing: str = 'ij',
        grid_offset: float = 0.,
        device: torch.device = 'cpu',
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Make coordinate grid matching offset and normalization of original.

    Returns:
        coords with shape (HW, 2) in [-1, 1].
    """
    # 0.5-centered indices with optional offset
    coords_h = torch.arange(0.5, height, device=device, dtype=torch.float32) + grid_offset
    coords_w = torch.arange(0.5, width, device=device, dtype=torch.float32) + grid_offset

    # Normalization denominators
    if normalize_coords == "max":
        denom = float(max(height, width))
        h_denom = denom
        w_denom = denom
    elif normalize_coords == "min":
        denom = float(min(height, width))
        h_denom = denom
        w_denom = denom
    elif normalize_coords == "separate":
        h_denom = float(height)
        w_denom = float(width)
    else:
        raise ValueError(f"Unknown normalize_coords: {normalize_coords}")

    # Normalize to [0, 1]
    coords_h = coords_h / h_denom
    coords_w = coords_w / w_denom
    coords_h = coords_h.to(dtype)
    coords_w = coords_w.to(dtype)

    # Create grid then map to [-1, 1]
    if grid_indexing == "xy":
        grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy")
        coords = torch.stack([grid_h, grid_w], dim=-1)  # (H, W, 2) -> (h, w order)
    else:
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # (H, W, 2)
    coords = coords.flatten(0, 1)  # (HW, 2)
    coords = 2.0 * coords - 1.0  # (HW, 2) in [-1, 1]
    return coords
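

# Usage sketch (editor's illustration): normalized, 0.5-centered coords for a
# 3x4 grid:
#
#   >>> coords = make_coords_dinov3(3, 4)
#   >>> coords.shape
#   torch.Size([12, 2])  # (H * W, 2), values in [-1, 1]

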
class RotaryEmbeddingDinoV3(nn.Module):
    """RoPE for timm DinoV3 port, numerically matching original.

    Math is aligned to original DinoV3 RopePositionEmbedding at
    https://github.com/facebookresearch/dinov3:
    - 0.5-centered coords normalized by H/W (or min/max), mapped to [-1, 1]
    - training-time augmentations (shift/jitter/rescale)
    - periods schedule equals Rope's temperature (base) or min/max period
    """

    def __init__(
            self,
            dim: int,
            temperature: Optional[float] = 100.0,
            min_period: Optional[float] = None,
            max_period: Optional[float] = None,
            feat_shape: Optional[List[int]] = None,
            normalize_coords: str = "separate",  # 'min', 'max', 'separate'
            grid_offset: float = 0.0,
            grid_indexing: str = "ij",
            rotate_half: bool = True,
            shift_coords: Optional[float] = None,
            jitter_coords: Optional[float] = None,  # interpreted as factor J >= 1
            rescale_coords: Optional[float] = None,  # interpreted as factor R >= 1
            device=None,
            dtype=None,
    ):
        super().__init__()
        # Dimensions / output format
        self.dim = dim  # equal to head_dim for most vit applications
        self.rotate_half = rotate_half

        # Period schedule parameters, temperature may be None when min/max periods are used
        self.temperature = float(temperature) if temperature is not None else None
        self.min_period = min_period
        self.max_period = max_period

        # Coord processing + augs
        self.normalize_coords = normalize_coords
        self.shift_coords = shift_coords
        self.jitter_coords = jitter_coords
        self.rescale_coords = rescale_coords
        self.aug_active = any([a is not None for a in [self.shift_coords, self.jitter_coords, self.rescale_coords]])

        # Grid config
        self.feat_shape = feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        # Precompute periods
        periods = self._compute_periods(device=device, dtype=dtype)
        self.register_buffer("periods", periods, persistent=False)

        if feat_shape is not None:
            self._cache_embed(feat_shape)
        else:
            self.register_buffer("pos_embed_cached", None, persistent=False)
            self.feat_shape = None

    def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Construct periods from either min/max or temperature."""
        dim = self.dim // 4
        if self.min_period is not None and self.max_period is not None:
            exponents = torch.linspace(0, 1, dim, device='cpu', dtype=torch.float32)
            periods = self.min_period * ((self.max_period / self.min_period) ** exponents)
        else:
            if self.temperature is None:
                raise ValueError("Provide either min/max periods or `temperature`.")
            exponents = 2.0 * torch.arange(dim, device='cpu', dtype=torch.float32) / (self.dim // 2)
            periods = self.temperature ** exponents

        # NOTE: The original dinov3 model weights have periods downcast to bfloat16 in persistent buffers,
        # loaded models will differ a bit vs timm as periods is not persistent and generated in float32 by default
        return periods.to(device=device, dtype=dtype)
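
    # Worked sketch (editor's illustration) of the default period schedule with
    # temperature=100 and dim=64: dim // 4 = 16 bands, exponents = 2k / 32 for
    # k in 0..15, so periods run from 100**0 = 1.0 up to 100**(15/16) ~= 75.0;
    # the angle for a normalized coord c at band k is 2 * pi * c / periods[k].
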
    def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
        """Apply shift/jitter/rescale train time augmentations."""
        if not self.training or not self.aug_active:
            return coords

        device = coords.device
        dtype = coords.dtype

        # Shift per-axis in [-s, +s]
        if self.shift_coords is not None:
            shift = float(self.shift_coords)
            shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, shift)
            coords = coords + shift_hw[None, :]

        # Jitter: per-axis log-uniform factor in [1/J, J]
        if self.jitter_coords is not None:
            jitter_factor = float(self.jitter_coords)
            if jitter_factor <= 0:
                raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).")
            jitter_max = math.log(jitter_factor)
            jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, jitter_max).exp()
            coords = coords * jitter_hw[None, :]

        # Rescale: shared scalar log-uniform factor in [1/R, R]
        if self.rescale_coords is not None:
            rescale_factor = float(self.rescale_coords)
            if rescale_factor <= 0:
                raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).")
            rescale_max = math.log(rescale_factor)
            rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, rescale_max).exp()
            coords = coords * rescale

        return coords

    def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return sin/cos embeddings with either 'half' or 'interleaved' layout."""
        # coords: (HW, 2); periods: (dim,)
        dim = self.dim // 4
        device = self.periods.device
        dtype = self.periods.dtype
        assert self.periods.numel() == dim

        # NOTE this is a slightly later device/dtype switch than original
        coords = coords[:, :, None].to(device=device, dtype=dtype)
        angles = 2 * math.pi * coords / self.periods[None, None, :]
        angles = angles.flatten(1)  # (HW, dim // 2)

        if self.rotate_half:
            # Tile (half layout) (HW, dim // 2) -> (HW, dim)
            angles = angles.tile(2)
        else:
            # Interleaved layout (HW, dim // 2) -> (HW, dim)
            angles = angles.repeat_interleave(2, dim=-1)

        sin = torch.sin(angles)
        cos = torch.cos(angles)
        return sin, cos

    def _create_embed(
            self,
            feat_shape: List[int],
            no_aug: bool = False,
    ) -> torch.Tensor:
        H, W = feat_shape
        coords = make_coords_dinov3(
            H, W,
            normalize_coords=self.normalize_coords,
            grid_indexing=self.grid_indexing,
            grid_offset=self.grid_offset,
        )  # (HW, 2)
        if not no_aug:
            coords = self._apply_coord_augs(coords)
        sin, cos = self._get_pos_embed_from_coords(coords)  # 2 * (HW, dim)
        rope_embed = torch.cat([sin, cos], dim=-1)  # (HW, 2 * dim)
        return rope_embed

    def _cache_embed(self, feat_shape: List[int]):
        # create non-augmented embeds for cache
        rope_embed = self._create_embed(feat_shape, no_aug=True)
        self.register_buffer("pos_embed_cached", rope_embed, persistent=False)
        self.feat_shape = feat_shape

    def update_feat_shape(self, feat_shape: List[int]):
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            # only update if feat_shape was set (valid cache) and different from previous value
            self._cache_embed(feat_shape)

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rope_embed matching DINOv3 RopePositionEmbedding numerics.

        Returns:
            (HW, 2 * dim) tensor with last dim = [sin, cos] concatenation.
        """
        if shape is not None:
            rope_embed = self._create_embed(shape)
        else:
            need_create = self.pos_embed_cached is None or (self.training and self.aug_active)
            if need_create:
                assert self.feat_shape is not None, 'feature shape must be cached on create'
                rope_embed = self._create_embed(self.feat_shape)
            else:
                assert self.pos_embed_cached is not None
                rope_embed = self.pos_embed_cached
        return rope_embed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Get and apply rotary embeddings to x"""
        # assuming channel-first tensor where spatial dims are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed, half=self.rotate_half)
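

# Usage sketch (editor's illustration): DINOv3-style rope with the 'half'
# layout; get_embed returns the concatenated sin/cos table:
#
#   >>> rope = RotaryEmbeddingDinoV3(dim=64, feat_shape=[14, 14])
#   >>> rope.get_embed().shape
#   torch.Size([196, 128])  # (H * W, 2 * dim)

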
def create_rope_embed(
        rope_type: str = 'cat',
        dim: int = 768,
        num_heads: int = 12,
        **kwargs
) -> nn.Module:
    """Factory function for creating rotary position embeddings.

    Args:
        rope_type: Type of RoPE to create. Options:
            - 'base': Basic RotaryEmbedding
            - 'cat': RotaryEmbeddingCat (concatenated sin/cos)
            - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
            - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
        dim: Total embedding dimension
        num_heads: Number of attention heads
        **kwargs: Additional arguments passed to the specific RoPE class

    Returns:
        Rotary embedding module
    """
    if rope_type == 'base':
        return RotaryEmbedding(dim=dim // num_heads, **kwargs)
    elif rope_type == 'cat':
        return RotaryEmbeddingCat(dim=dim // num_heads, **kwargs)
    elif rope_type == 'mixed':
        # Mixed requires depth parameter, generates differing embeddings per layer and head
        return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs)
    elif rope_type == 'dinov3':
        return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs)
    else:
        raise ValueError(f"Unknown RoPE type: {rope_type}")
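

# Usage sketch (editor's illustration): dim is the total embed dim; the
# per-head variants receive dim // num_heads internally:
#
#   >>> rope = create_rope_embed('cat', dim=768, num_heads=12)
#   >>> isinstance(rope, RotaryEmbeddingCat)
#   True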