Files
railseek6/openclip_env/Lib/site-packages/timm/layers/pos_embed_sincos.py

1177 lines
43 KiB
Python

""" Sin-cos, fourier, rotary position embedding modules and functions
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
from ._fx import register_notrace_function
from .grid import ndgrid
from .trace_utils import _assert
def pixel_freq_bands(
        num_bands: int,
        max_freq: float = 224.,
        linear_bands: bool = True,
        device: Optional[torch.device] = None,
):
    """Generate frequency bands for pixel-coordinate (NeRF style) position embedding.

    Bands span up to max_freq / 2, spaced either linearly or as powers of two,
    then scaled by pi so a unit coordinate covers half a period at band one.
    """
    half_max = max_freq / 2
    if linear_bands:
        bands = torch.linspace(1.0, half_max, num_bands, dtype=torch.float32, device=device)
    else:
        exponents = torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
        bands = 2 ** exponents
    return bands * torch.pi
def freq_bands(
        num_bands: int,
        temperature: float = 10000.,
        step: int = 2,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate inverse-frequency bands (language-model style).

    Returns 1 / temperature ** (i / num_bands) for i = 0, step, 2*step, ... < num_bands.
    """
    exponents = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32)
    exponents = exponents / num_bands
    return 1. / (temperature ** exponents)
def build_sincos2d_pos_embed(
        feat_shape: List[int],
        dim: int = 64,
        temperature: float = 10000.,
        reverse_coord: bool = False,
        interleave_sin_cos: bool = False,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build a fixed 2D sin-cos position embedding table.

    Args:
        feat_shape: Spatial shape (H, W) of the feature map.
        dim: Embedding dim, must be divisible by 4 (sin + cos over two axes).
        temperature: Frequency temperature for the band schedule.
        reverse_coord: stack grid order W, H instead of H, W
        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tensor of shape (H * W, dim).
    """
    assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
    bands = freq_bands(dim // 4, temperature=temperature, step=1, device=device)

    shape = feat_shape[::-1] if reverse_coord else feat_shape  # stack W, H instead of H, W when reversed
    coords = [
        torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
        for s in shape
    ]
    grid = torch.stack(ndgrid(coords)).flatten(1).transpose(0, 1)

    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?

    # dim=2 interleaves sin/cos per frequency, dim=1 groups all sin then all cos
    stack_dim = 2 if interleave_sin_cos else 1
    pos_emb = torch.stack([pos2.sin(), pos2.cos()], dim=stack_dim).flatten(1)
    return pos_emb.to(dtype=dtype)
def swap_shape_xy(seq: List[int]) -> List[int]:
    """Swap the first two entries of a shape sequence (H, W <-> W, H), keeping the rest."""
    if len(seq) < 2:
        return seq
    swapped = [seq[1], seq[0]]
    swapped.extend(seq[2:])
    return swapped
def build_fourier_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        include_grid: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        grid_offset: float = 0.,
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> List[torch.Tensor]:
    """Build sin/cos Fourier features for all positions on a spatial grid.

    Args:
        feat_shape: Feature shape for embedding.
        bands: Pre-calculated frequency bands.
        num_bands: Number of frequency bands (determines output dim).
        max_res: Maximum resolution for pixel based freq.
        temperature: Temperature for non-pixel freq.
        linear_bands: Linear band spacing for pixel based freq.
        include_grid: Include the spatial grid in output.
        in_pixels: Output in pixel freq.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        grid_offset: Constant offset to add to grid for non-pixel freq.
        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
        dtype: Output dtype.
        device: Output device.

    Returns:
        [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos], where
        sin/cos tensors have shape (*feat_shape, len(feat_shape), num_bands).
    """
    if bands is None:
        # no pre-computed bands, generate them (pixel vs inv-freq language style)
        if in_pixels:
            bands = pixel_freq_bands(
                num_bands,
                float(max_res),
                linear_bands=linear_bands,
                device=device,
            )
        else:
            bands = freq_bands(
                num_bands,
                temperature=temperature,
                step=1,
                device=device,
            )
    else:
        # pre-supplied bands provide default device / dtype when unset
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if grid_indexing == 'xy':
        # swap H/W ordering to match xy meshgrid indexing
        feat_shape = swap_shape_xy(feat_shape)
        if ref_feat_shape is not None:
            ref_feat_shape = swap_shape_xy(ref_feat_shape)

    if in_pixels:
        # normalized [-1, 1] pixel coordinates per axis
        t = [
            torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32)
            for s in feat_shape
        ]
    else:
        # integer grid coordinates per axis (plus optional constant offset)
        t = [
            torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + grid_offset
            for s in feat_shape
        ]

    if ref_feat_shape is not None:
        # eva's scheme for resizing rope embeddings (ref shape = pretrain)
        t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]

    grid = torch.stack(torch.meshgrid(t, indexing=grid_indexing), dim=-1)
    grid = grid.unsqueeze(-1)  # add band dim for broadcasting against bands
    pos = grid * bands

    pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype=dtype)
    out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
    return out
class FourierEmbed(nn.Module):
    """Concatenate fixed Fourier position features (and optionally the grid) onto input features."""

    def __init__(
            self,
            max_res: int = 224,
            num_bands: int = 64,
            concat_grid=True,
            keep_spatial=False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            max_res: Maximum resolution used for the pixel frequency band schedule.
            num_bands: Number of frequency bands.
            concat_grid: Also concatenate the raw coordinate grid to the features.
            keep_spatial: Keep NCHW output layout instead of flattening to (B, N, C).
            device: Device for the cached bands buffer.
            dtype: Dtype for the cached bands buffer.
        """
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        self.register_buffer(
            'bands',
            # FIX: arguments were swapped; pixel_freq_bands signature is
            # (num_bands, max_freq) — matches the call in RotaryEmbedding.
            pixel_freq_bands(num_bands, float(max_res)).to(device=device, dtype=dtype),
            persistent=False,
        )

    def forward(self, x):
        # x assumed channel-first with 2 spatial dims (B, C, H, W) — see FIXME below
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device,
        )
        emb = torch.cat(emb, dim=-1)
        # move band features ahead of the per-axis dim, then flatten to one channel axis
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            # append embedding as extra channels, output stays (B, C + E, H, W)
            x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
        else:
            # channels-last concat then flatten spatial dims, output (B, H * W, C + E)
            x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
            x = x.reshape(B, feat_shape.numel(), -1)

        return x
def rot(x):
    """Rotate adjacent channel pairs: [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]."""
    odds = x[..., 1::2]
    evens = x[..., ::2]
    return torch.stack([-odds, evens], -1).reshape(x.shape)
def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate channel halves: [x0, x1, x2, x3, x4, x5] -> [-x3, -x4, -x5, x0, x1, x2]."""
    lower, upper = x.chunk(2, dim=-1)
    return torch.cat([upper.neg(), lower], dim=-1)
def apply_rot_embed(
        x: torch.Tensor,
        sin_emb: torch.Tensor,
        cos_emb: torch.Tensor,
        half: bool = False,
) -> torch.Tensor:
    """Apply rotary embedding to x with separate sin/cos tensors.

    With half=True the 'half' layout is used (sin/cos repeated as two halves,
    paired with rope_rotate_half); otherwise the interleaved pair layout
    (sin/cos repeated per adjacent pair, paired with rot).
    """
    if half:
        rotated = rope_rotate_half(x)
    else:
        rotated = rot(x)
    return x * cos_emb + rotated * sin_emb
def apply_rot_embed_list(
        x: List[torch.Tensor],
        sin_emb: torch.Tensor,
        cos_emb: torch.Tensor,
        half: bool = False
) -> List[torch.Tensor]:
    """Apply rotary embedding to each tensor in a list (a bare tensor is wrapped first).

    half selects the rotate-half layout vs the interleaved pair layout,
    matching apply_rot_embed.
    """
    if isinstance(x, torch.Tensor):
        x = [x]
    rotate = rope_rotate_half if half else rot
    return [t * cos_emb + rotate(t) * sin_emb for t in x]
def apply_rot_embed_cat(
        x: torch.Tensor,
        emb: torch.Tensor,
        half: bool = False
) -> torch.Tensor:
    """Apply rotary embedding where sin and cos are concatenated on emb's last dim.

    half selects the rotate-half layout vs the interleaved pair layout,
    matching apply_rot_embed.
    """
    sin_emb, cos_emb = emb.chunk(2, -1)
    rotated = rope_rotate_half(x) if half else rot(x)
    return x * cos_emb + rotated * sin_emb
def apply_keep_indices_nlc(
        x: torch.Tensor,
        pos_embed: torch.Tensor,
        keep_indices: torch.Tensor,
        pos_embed_has_batch: bool = False,
) -> torch.Tensor:
    """ Apply keep indices to different ROPE shapes

    Expected pos_embed shapes:
    * [seq_len, pos_embed_dim] --> output [batch_size, seq_len, pos_embed_dim]
    * [num_heads, seq_len, pos_embed_dim] --> output [batch_size, num_heads, seq_len, pos_embed_dim]
    * [depth, num_heads, seq_len, pos_embed_dim] --> output [batch_size, depth, num_heads, seq_len, pos_embed_dim]

    And all of the above with leading batch dimension already present if `pos_embed_has_batch == True`

    Args:
        x: Input tensor; only its batch size (dim 0) is read, to expand pos_embed.
        pos_embed: Position embedding to subset, shapes as above.
        keep_indices: 2D tensor of per-sample sequence indices to keep (dims 0 and 1 are used).
        pos_embed_has_batch: Set when pos_embed already carries a leading batch dim.

    Returns:
        pos_embed gathered along the seq_len (second-last) dim at keep_indices.
    """
    if pos_embed_has_batch:
        # Pos embed already includes batch dim
        _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions')  # At least [batch, seq_len, pos_embed_dim]
    else:
        # Add batch dimension and expand to batch size
        _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions')  # At least [seq_len, pos_embed_dim]
        expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
        pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)

    # Reshape keep_indices to add singleton dims
    keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1)
    keep_indices = keep_indices.view(keep_shape)

    # Expand all dims to match position embedding except the gather dim (second-last)
    keep_expand = list(pos_embed.shape)
    keep_expand[-2] = -1
    keep_indices = keep_indices.expand(keep_expand)

    return pos_embed.gather(-2, keep_indices)
def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        grid_offset: float = 0.,
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
):
    """Build flattened sin/cos tables for rotary position embedding.

    Args:
        feat_shape: Spatial shape of the target tensor for embedding.
        bands: Optional pre-generated frequency bands
        dim: Output dimension of embedding tensor.
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode
        linear_bands: Linearly (instead of log) spaced bands for pixel mode
        in_pixels: Pixel vs language (inv freq) mode.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        grid_offset: Constant offset to add to grid for non-pixel freq.
        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tuple of (sin, cos) tensors, one row per spatial position, with values
        duplicated in adjacent pairs (repeat_interleave) to line up with the
        rotated channel pairs produced by rot().
    """
    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,  # dim // 4 bands per spatial axis, sin + cos each
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        grid_offset=grid_offset,
        grid_indexing=grid_indexing,
        device=device,
        dtype=dtype,
    )
    num_spatial_dim = 1
    # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
    for x in feat_shape:
        num_spatial_dim *= x
    sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb
class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
    been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
            grid_offset: float = 0.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Embedding dim (typically the per-head dim).
            max_res: Maximum resolution for pixel mode band schedule.
            temperature: Temperature (inv freq) for non-pixel mode.
            in_pixels: Pixel vs language (inv freq) mode.
            linear_bands: Linearly spaced bands for pixel mode.
            feat_shape: Spatial shape; when given, full sin/cos tables are pre-computed.
            ref_feat_shape: Reference feature shape for resize / fine-tune.
            grid_offset: Constant offset to add to grid for non-pixel freq.
            grid_indexing: Indexing mode for meshgrid ('ij' or 'xy').
            device: Device for cached buffers.
            dtype: Dtype for cached buffers.
        """
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.linear_bands = linear_bands
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands.to(device=device, dtype=dtype),
                persistent=False,
            )
            self.pos_embed_sin = None
            self.pos_embed_cos = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            emb_sin, emb_cos = self._get_pos_embed_values(feat_shape, device=device, dtype=dtype)
            self.bands = None
            self.register_buffer(
                'pos_embed_sin',
                emb_sin,
                persistent=False,
            )
            self.register_buffer(
                'pos_embed_cos',
                emb_cos,
                persistent=False,
            )

    def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
        # build flattened sin/cos tables for the given spatial shape
        emb_sin, emb_cos = build_rotary_pos_embed(
            feat_shape=feat_shape,
            dim=self.dim,
            max_res=self.max_res,
            temperature=self.temperature,
            linear_bands=self.linear_bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
            device=device,
            dtype=dtype,
        )
        return emb_sin, emb_cos

    def update_feat_shape(self, feat_shape: List[int]):
        """Re-generate cached sin/cos tables for a new feature shape (no-op if unchanged or uncached)."""
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            # only update if feat_shape was set and different from previous value
            assert self.pos_embed_sin is not None
            assert self.pos_embed_cos is not None
            self.pos_embed_sin, self.pos_embed_cos = self._get_pos_embed_values(
                feat_shape,
                device=self.pos_embed_sin.device,
                dtype=self.pos_embed_sin.dtype,
            )
            self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None):
        """Return (sin, cos) embedding tables, rebuilding from cached bands when a shape is given."""
        if shape is not None and self.bands is not None:
            # rebuild embeddings every call, use if target shape changes
            return build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
                grid_offset=self.grid_offset,
                grid_indexing=self.grid_indexing,
            )
        elif self.pos_embed_sin is not None and self.pos_embed_cos is not None:
            return self.pos_embed_sin, self.pos_embed_cos
        else:
            assert False, "get_embed() requires pre-computed pos embeds or valid shape w/ pre-computed bands"

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
class RotaryEmbeddingCat(nn.Module):
    """ Rotary position embedding w/ concatenated sin & cos

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim: int,
            max_res: int = 224,
            temperature: float = 10000,
            in_pixels: bool = True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
            grid_offset: float = 0.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Embedding dim (typically the per-head dim).
            max_res: Maximum resolution for pixel mode band schedule.
            temperature: Temperature (inv freq) for non-pixel mode.
            in_pixels: Pixel vs language (inv freq) mode.
            linear_bands: Linearly spaced bands for pixel mode.
            feat_shape: Spatial shape; when given, the full embed table is pre-computed.
            ref_feat_shape: Reference feature shape for resize / fine-tune.
            grid_offset: Constant offset to add to grid for non-pixel freq.
            grid_indexing: Indexing mode for meshgrid ('ij' or 'xy').
            device: Device for cached buffers.
            dtype: Dtype for cached buffers.
        """
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.linear_bands = linear_bands
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands.to(device=device, dtype=dtype),
                persistent=False,
            )
            self.pos_embed = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            self.bands = None
            self.register_buffer(
                'pos_embed',
                self._get_pos_embed_values(feat_shape=feat_shape, device=device, dtype=dtype),
                persistent=False,
            )

    def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
        # build sin/cos tables and concatenate them on the last dim
        embeds = build_rotary_pos_embed(
            feat_shape=feat_shape,
            dim=self.dim,
            max_res=self.max_res,
            temperature=self.temperature,
            linear_bands=self.linear_bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
            device=device,
            dtype=dtype,
        )
        return torch.cat(embeds, -1)

    def update_feat_shape(self, feat_shape: List[int]):
        """Re-generate the cached pos_embed for a new feature shape (no-op if unchanged or uncached)."""
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            # only update if feat_shape was set and different from previous value
            assert self.pos_embed is not None
            self.pos_embed = self._get_pos_embed_values(
                feat_shape,
                device=self.pos_embed.device,
                dtype=self.pos_embed.dtype,
            )
            self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None):
        """Return concatenated sin/cos embedding, rebuilding from cached bands when a shape is given."""
        if shape is not None and self.bands is not None:
            # rebuild embeddings from cached bands every call, use if target shape changes
            embeds = build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
                grid_offset=self.grid_offset,
                grid_indexing=self.grid_indexing,
            )
            return torch.cat(embeds, -1)
        elif self.pos_embed is not None:
            return self.pos_embed
        else:
            assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"

    def get_batch_embeds(
            self,
            shapes: List[Tuple[int, int]],
            seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Generate ROPE embeddings for multiple grid shapes efficiently.

        Computes embeddings for the maximum grid size once, then extracts
        and flattens the relevant portions for each requested shape.

        Args:
            shapes: List of (H, W) tuples representing different grid sizes
            seq_len: If provided, return one zero-padded (len(shapes), seq_len, dim) tensor.

        Returns:
            List of concatenated sin/cos embeddings for each shape,
            where each tensor has shape (H*W, dim), or a padded tensor when seq_len is set.

        Raises:
            RuntimeError: If this module was built with a fixed feat_shape (no cached bands).
        """
        if not shapes:
            return []

        # Check if we have pre-computed bands
        if self.bands is None:
            # If we have pre-computed pos_embed for a fixed shape, we can't do batch generation
            raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")

        # Find max dimensions across all shapes
        max_h = max(h for h, w in shapes)
        max_w = max(w for h, w in shapes)

        # Generate embeddings for max size ONCE
        sin_emb, cos_emb = build_rotary_pos_embed(
            feat_shape=(max_h, max_w),
            bands=self.bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
        )

        # concat sin/cos on last dim and reshape to a 2D grid for slicing
        rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)

        if seq_len is not None:
            # pad each flattened slice out to seq_len rows of zeros
            flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb)
            for i, (h, w) in enumerate(shapes):
                src_len = h * w
                flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
            return flat_embeds
        else:
            flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
            return flat_embeds_list

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)
def init_random_2d_freqs(
        head_dim: int,
        depth: int,
        num_heads: int,
        temperature: float = 10.0,
        rotate: bool = True,
        *,
        device=None,
        dtype=torch.float32,
) -> torch.Tensor:
    """ Vectorised 2D ROPE frequencies with random rotation for mixed mode ROPE.

    Returns:
        Tensor (2, depth, num_heads, head_dim//2)
    """
    # base magnitudes, reshaped to (1, 1, head_dim//4) to broadcast over depth & heads
    exponent = torch.arange(0, head_dim, 4, device=device, dtype=dtype) / head_dim
    mag = (1.0 / (temperature ** exponent)).unsqueeze(0).unsqueeze(0)

    # per-(block, head) rotation angle; zero when rotate disabled
    if rotate:
        angles = torch.rand(depth, num_heads, 1, device=device, dtype=dtype) * 2 * torch.pi
    else:
        angles = torch.zeros(depth, num_heads, 1, device=device, dtype=dtype)

    # concat the angle and its quarter-turn offset -> (depth, num_heads, head_dim//2)
    quarter = torch.pi / 2
    fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(angles + quarter)], dim=-1)
    fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(angles + quarter)], dim=-1)

    # stack x/y axes -> (2, depth, num_heads, head_dim//2)
    return torch.stack([fx, fy], dim=0)
@torch.fx.wrap
@register_notrace_function
def get_mixed_grid(
        shape: List[int],
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build flattened x/y position vectors for mixed-mode ROPE over a 2D grid."""
    if grid_indexing == 'xy':
        shape = swap_shape_xy(shape)
    axis0 = torch.arange(shape[0], device=device, dtype=torch.float32)
    axis1 = torch.arange(shape[1], device=device, dtype=torch.float32)
    x_pos, y_pos = torch.meshgrid(axis0, axis1, indexing=grid_indexing)
    return x_pos.to(dtype).flatten(), y_pos.to(dtype).flatten()
def get_mixed_freqs(
        freqs: torch.Tensor,
        t_x: torch.Tensor,
        t_y: torch.Tensor,
) -> torch.Tensor:
    """Combine learnable per-axis frequencies with grid positions into sin/cos embeds.

    freqs[0]/freqs[1] hold x/y axis frequencies; computation runs in float32 and the
    result is cast back to the input frequency dtype.
    """
    out_dtype = freqs.dtype
    freqs = freqs.float()
    # outer products of positions with per-axis frequencies
    x_angles = t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)
    y_angles = t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)
    angles = x_angles + y_angles
    # duplicate each frequency for the rotated channel pairs, then concat sin | cos
    sin_emb = angles.sin().repeat_interleave(2, -1)
    cos_emb = angles.cos().repeat_interleave(2, -1)
    return torch.cat([sin_emb, cos_emb], dim=-1).to(out_dtype)
class RotaryEmbeddingMixed(nn.Module):
    """Rotary position embedding with depth-dependent learnable frequencies.

    This implementation supports mixed (learnable) ROPE. In mixed mode,
    each transformer block has its own set of learnable frequency parameters.

    Based on 'Rotary Position Embedding for Vision: https://arxiv.org/abs/2403.13298)'
    Compatible with original at https://github.com/naver-ai/rope-vit
    """
    def __init__(
            self,
            dim: int,
            depth: int,
            num_heads: int,
            temperature: float = 10.0,
            feat_shape: Optional[List[int]] = None,
            grid_indexing: str = 'xy',
            device=None,
            dtype=None,
    ):
        """Initialize rotary embeddings.

        Args:
            dim: Embedding dimension (should be divisible by 4)
            depth: Number of transformer blocks
            num_heads: Number of attention heads
            temperature: Base for frequency computation
            feat_shape: Spatial dimensions [H, W] if known in advance
            grid_indexing: How to index grid positions ('xy' or 'ij')
            device: Device for the frequency parameter and cached grids
            dtype: Dtype for the frequency parameter
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.temperature = temperature
        self.feat_shape = feat_shape
        self.grid_indexing = grid_indexing

        head_dim = dim // num_heads
        assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"
        freqs = init_random_2d_freqs(
            head_dim,
            depth,
            num_heads,
            temperature=temperature,
            rotate=True,
            device=device,
            dtype=dtype,
        )  # (2, depth, num_heads, head_dim//2)
        self.freqs = nn.Parameter(freqs)  # learnable, excluded from weight decay via no_weight_decay()

        if feat_shape is not None:
            # cache pre-computed grid
            t_x, t_y = self._get_grid_values(feat_shape)
            self.register_buffer('t_x', t_x, persistent=False)
            self.register_buffer('t_y', t_y, persistent=False)
        else:
            self.t_x = self.t_y = None

    def _get_grid_values(self, feat_shape: Optional[List[int]]):
        # flattened x/y coordinate vectors on the same device as the freq params
        t_x, t_y = get_mixed_grid(
            feat_shape,
            grid_indexing=self.grid_indexing,
            device=self.freqs.device,
        )
        return t_x, t_y

    def update_feat_shape(self, feat_shape: Optional[List[int]]):
        """Re-generate cached grid coords for a new feature shape (no-op if unchanged or uncached)."""
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            assert self.t_x is not None
            assert self.t_y is not None
            t_x, t_y = self._get_grid_values(feat_shape)
            self.t_x = t_x.to(self.t_x.device, self.t_x.dtype)
            self.t_y = t_y.to(self.t_y.device, self.t_y.dtype)
            self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rotary embeddings for the given spatial shape.

        Args:
            shape: Spatial dimensions [H, W]

        Returns:
            Tensor of shape (depth, H*W, dim) containing concatenated sin/cos embeddings
        """
        if shape is not None:
            t_x, t_y = get_mixed_grid(
                shape,
                grid_indexing=self.grid_indexing,
                device=self.freqs.device
            )
        elif self.t_x is not None and self.t_y is not None:
            t_x, t_y = self.t_x, self.t_y
        else:
            assert False, "get_embed() requires pre-computed t_x/t_y or valid shape"

        return get_mixed_freqs(self.freqs, t_x, t_y)

    def get_batch_embeds(
            self,
            shapes: List[Tuple[int, int]],
            seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Generate ROPE embeddings for multiple grid shapes efficiently.

        Computes embeddings for the maximum grid size once, then extracts
        and flattens the relevant portions for each requested shape.

        Args:
            shapes: List of (H, W) tuples representing different grid sizes
            seq_len: If provided, return padded tensor of this length. Otherwise return list.

        Returns:
            If seq_len is provided: Padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim)
            Otherwise: List of tensors with shape (depth, num_heads, H*W, dim) for each shape
        """
        if not shapes:
            return []

        # Find max dimensions
        max_h = max(h for h, w in shapes)
        max_w = max(w for h, w in shapes)

        # Generate embeddings for max size ONCE
        t_x, t_y = get_mixed_grid(
            [max_h, max_w],
            grid_indexing=self.grid_indexing,
            device=self.freqs.device
        )
        max_embed = get_mixed_freqs(self.freqs, t_x, t_y)  # (depth, num_heads, max_h*max_w, dim)

        # Reshape to 2D grid for easy slicing
        depth, num_heads, _, dim = max_embed.shape
        max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim)

        if seq_len is not None:
            # Return padded tensor
            B = len(shapes)
            padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype)
            for i, (h, w) in enumerate(shapes):
                # Slice and flatten
                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
                actual_len = h * w
                padded[i, :, :, :actual_len] = embed_slice
            return padded
        else:
            # Return list
            results = []
            for h, w in shapes:
                # Slice and flatten
                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
                results.append(embed_slice)
            return results

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)

    def no_weight_decay(self):
        """Exclude frequency parameters from weight decay."""
        return {'freqs'}
@torch.fx.wrap
@register_notrace_function
def make_coords_dinov3(
        height: int,
        width: int,
        normalize_coords: str = 'separate',
        grid_indexing: str = 'ij',
        grid_offset: float = 0.,
        device: torch.device = 'cpu',
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Make coordinate grid matching offset and normalization of original.

    Args:
        height: Grid height.
        width: Grid width.
        normalize_coords: Denominator mode: 'max' / 'min' use max(H, W) / min(H, W)
            for both axes, 'separate' normalizes each axis by its own extent.
        grid_indexing: 'ij' or 'xy' meshgrid ordering (output stays (h, w) in last dim).
        grid_offset: Constant added to the 0.5-centered indices.
        device: Output device.
        dtype: Output dtype.

    Returns: coords with shape (HW, 2) in [-1, 1].

    Raises:
        ValueError: If normalize_coords is not 'max', 'min', or 'separate'.
    """
    # 0.5-centered indices with optional offset
    coords_h = torch.arange(0.5, height, device=device, dtype=torch.float32) + grid_offset
    coords_w = torch.arange(0.5, width, device=device, dtype=torch.float32) + grid_offset

    # Normalization denominators
    if normalize_coords == "max":
        denom = float(max(height, width))
        h_denom = denom
        w_denom = denom
    elif normalize_coords == "min":
        denom = float(min(height, width))
        h_denom = denom
        w_denom = denom
    elif normalize_coords == "separate":
        h_denom = float(height)
        w_denom = float(width)
    else:
        raise ValueError(f"Unknown normalize_coords: {normalize_coords}")

    # Normalize to [0, 1]
    coords_h = coords_h / h_denom
    coords_w = coords_w / w_denom
    coords_h = coords_h.to(dtype)
    coords_w = coords_w.to(dtype)

    # Create grid then map to [-1, 1]
    if grid_indexing == "xy":
        grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy")
        coords = torch.stack([grid_h, grid_w], dim=-1)  # (H, W, 2) -> (h, w order)
    else:
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # (H, W, 2)
    coords = coords.flatten(0, 1)  # (HW, 2)
    coords = 2.0 * coords - 1.0  # (H, W, 2) in [-1, 1]
    return coords
class RotaryEmbeddingDinoV3(nn.Module):
    """RoPE for timm DinoV3 port, numerically matching original.

    Math is aligned to original DinoV3 RopePositionEmbedding at https://github.com/facebookresearch/dinov3:
    - 0.5-centered coords normalized by H/W (or min/max), mapped to [-1,1]
    - training-time augmentations (shift/jitter/rescale)
    - periods schedule equals Rope's temperature (base) or min/max period
    """

    def __init__(
            self,
            dim: int,
            temperature: Optional[float] = 100.0,
            min_period: Optional[float] = None,
            max_period: Optional[float] = None,
            feat_shape: Optional[List[int]] = None,
            normalize_coords: str = "separate",  # 'min', 'max', 'separate'
            grid_offset: float = 0.0,
            grid_indexing: str = "ij",
            rotate_half: bool = True,
            shift_coords: Optional[float] = None,
            jitter_coords: Optional[float] = None,  # interpreted as factor J >= 1
            rescale_coords: Optional[float] = None,  # interpreted as factor R >= 1
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Embedding dim (equal to head_dim for most vit applications).
            temperature: Base for the period schedule; may be None when min/max periods given.
            min_period: Minimum period, used together with max_period instead of temperature.
            max_period: Maximum period, used together with min_period instead of temperature.
            feat_shape: Spatial (H, W) shape; when given, non-augmented embeds are cached.
            normalize_coords: Coordinate normalization mode: 'min', 'max', or 'separate'.
            grid_offset: Constant offset added to the 0.5-centered grid coords.
            grid_indexing: 'ij' or 'xy' meshgrid ordering.
            rotate_half: Use rotate-half (tiled) layout instead of interleaved pairs.
            shift_coords: Train-time per-axis uniform shift bound.
            jitter_coords: Train-time per-axis log-uniform scale factor bound.
            rescale_coords: Train-time shared log-uniform scale factor bound.
            device: Device for precomputed buffers.
            dtype: Dtype for precomputed buffers.
        """
        super().__init__()

        # Dimensions / output format
        self.dim = dim  # equal to head_dim for most vit applications
        self.rotate_half = rotate_half

        # Period schedule parameters
        # FIX: temperature is Optional and _compute_periods supports temperature=None when
        # min/max periods are provided; unconditional float(temperature) raised TypeError.
        self.temperature = float(temperature) if temperature is not None else None
        self.min_period = min_period
        self.max_period = max_period

        # Coord processing + augs
        self.normalize_coords = normalize_coords
        self.shift_coords = shift_coords
        self.jitter_coords = jitter_coords
        self.rescale_coords = rescale_coords
        self.aug_active = any([a is not None for a in [self.shift_coords, self.jitter_coords, self.rescale_coords]])

        # Grid config
        self.feat_shape = feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        # Precompute periods
        periods = self._compute_periods(device=device, dtype=dtype)
        self.register_buffer("periods", periods, persistent=False)

        if feat_shape is not None:
            self._cache_embed(feat_shape)
        else:
            self.register_buffer("pos_embed_cached", None, persistent=False)
            self.feat_shape = None

    def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Construct periods from either min/max or temperature."""
        dim = self.dim // 4
        if self.min_period is not None and self.max_period is not None:
            # geometric progression from min_period to max_period
            exponents = torch.linspace(0, 1, dim, device='cpu', dtype=torch.float32)
            periods = self.min_period * ((self.max_period / self.min_period) ** exponents)
        else:
            if self.temperature is None:
                raise ValueError("Provide either min/max periods or `temperature`.")
            exponents = 2.0 * torch.arange(dim, device='cpu', dtype=torch.float32) / (self.dim // 2)
            periods = self.temperature ** exponents
        # NOTE: The original dinv3 model weights have periods downcast to bfloat16 in persistent buffers,
        # loaded models will differ a bit vs timm as periods is not persistent and generated in float32 by default
        return periods.to(device=device, dtype=dtype)

    def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
        """Apply shift/jitter/rescale train time augmentations."""
        if not self.training or not self.aug_active:
            return coords

        device = coords.device
        dtype = coords.dtype

        # Shift per-axis in [-s, +s]
        if self.shift_coords is not None:
            shift = float(self.shift_coords)
            shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, shift)
            coords = coords + shift_hw[None, :]

        # Jitter: per-axis log-uniform factor in [1/J, J]
        if self.jitter_coords is not None:
            jitter_factor = float(self.jitter_coords)
            if jitter_factor <= 0:
                raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).")
            jitter_max = math.log(jitter_factor)
            jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, jitter_max).exp()
            coords = coords * jitter_hw[None, :]

        # Rescale: shared scalar log-uniform factor in [1/R, R]
        if self.rescale_coords is not None:
            rescale_factor = float(self.rescale_coords)
            if rescale_factor <= 0:
                raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).")
            rescale_max = math.log(rescale_factor)
            rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, rescale_max).exp()
            coords = coords * rescale

        return coords

    def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return sin/cos embeddings with either 'half' or 'interleaved' layout."""
        # coords: (HW, 2); periods: (dim // 4,)
        dim = self.dim // 4
        device = self.periods.device
        dtype = self.periods.dtype
        assert self.periods.numel() == dim

        # NOTE this is a slightly later device/dtype switch than original
        coords = coords[:, :, None].to(device=device, dtype=dtype)
        angles = 2 * math.pi * coords / self.periods[None, None, :]
        angles = angles.flatten(1)  # (HW, dim // 2)

        if self.rotate_half:
            # Tile (half layout) (HW, dim // 2) -> (HW, dim)
            angles = angles.tile(2)
        else:
            # Interleaved layout (HW, dim // 2) -> (HW, dim)
            angles = angles.repeat_interleave(2, dim=-1)

        sin = torch.sin(angles)
        cos = torch.cos(angles)
        return sin, cos

    def _create_embed(
            self,
            feat_shape: List[int],
            no_aug: bool = False,
    ) -> torch.Tensor:
        """Build the (HW, 2 * dim) concatenated sin/cos table for feat_shape."""
        H, W = feat_shape
        coords = make_coords_dinov3(
            H, W,
            normalize_coords=self.normalize_coords,
            grid_indexing=self.grid_indexing,
            grid_offset=self.grid_offset,
        )  # (HW, 2)
        if not no_aug:
            coords = self._apply_coord_augs(coords)
        sin, cos = self._get_pos_embed_from_coords(coords)  # 2 * (HW, dim)
        rope_embed = torch.cat([sin, cos], dim=-1)  # (HW, 2*dim)
        return rope_embed

    def _cache_embed(self, feat_shape: List[int]):
        # create non-augmented embeds for cache
        rope_embed = self._create_embed(feat_shape, no_aug=True)
        self.register_buffer("pos_embed_cached", rope_embed, persistent=False)
        self.feat_shape = feat_shape

    def update_feat_shape(self, feat_shape: List[int]):
        """Re-generate cached embeds for a new feature shape (no-op if unchanged or uncached)."""
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            # only update if feat_shape was set (valid cache) and different from previous value
            self._cache_embed(feat_shape)

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rope_embed matching DINOv3 RopePositionEmbedding numerics.

        Returns: (HW, 2 * dim) tensor with last dim = [sin, cos] cat.
        """
        if shape is not None:
            rope_embed = self._create_embed(shape)
        else:
            # augmented embeds cannot come from cache, rebuild them while training
            need_create = self.pos_embed_cached is None or (self.training and self.aug_active)
            if need_create:
                assert self.feat_shape is not None, 'feature shape must be cached on create'
                rope_embed = self._create_embed(self.feat_shape)
            else:
                assert self.pos_embed_cached is not None
                rope_embed = self.pos_embed_cached
        return rope_embed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Get and apply rotary embeddings to x"""
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed, half=self.rotate_half)
def create_rope_embed(
        rope_type: str = 'cat',
        dim: int = 768,
        num_heads: int = 12,
        **kwargs
) -> nn.Module:
    """Factory function for creating rotary position embeddings.

    Args:
        rope_type: Type of RoPE to create. Options:
            - 'base': Basic RotaryEmbedding
            - 'cat': RotaryEmbeddingCat (concatenated sin/cos)
            - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
            - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
        dim: Total embedding dimension
        num_heads: Number of attention heads
        **kwargs: Additional arguments passed to the specific RoPE class

    Returns:
        Rotary embedding module

    Raises:
        ValueError: If rope_type is not one of the options above.
    """
    # per-head dim for all variants except 'mixed', which takes the total dim + num_heads
    head_dim = dim // num_heads
    if rope_type == 'base':
        return RotaryEmbedding(dim=head_dim, **kwargs)
    if rope_type == 'cat':
        return RotaryEmbeddingCat(dim=head_dim, **kwargs)
    if rope_type == 'mixed':
        # Mixed requires depth parameter, generates differing embeddings per layer and head
        return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs)
    if rope_type == 'dinov3':
        return RotaryEmbeddingDinoV3(dim=head_dim, **kwargs)
    raise ValueError(f"Unknown RoPE type: {rope_type}")