# Source: timm/optim/muon.py (687 lines, 26 KiB, Python)
""" Muon Optimizer
Improved Muon optimizer implementation with flexible handling of high-dimensional tensors.
Combines PyTorch-style structure with options for:
- Batched spatial processing for convolutions in addition to flatten
- Optional spatial normalization
- Selectable coefficient presets
- Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) and optional fallback via param groups
Based on implementation by Keller Jordan, see
- https://github.com/KellerJordan/Muon/blob/master/muon.py
- https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
- https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py
- https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
Hacked together by Ross Wightman
"""
import logging
import numbers
from typing import List, Mapping, Optional, Sequence, Tuple, Union
import torch
from ._types import ParamsT
from .adamw import adamw
from .nadamw import nadamw
_logger = logging.getLogger(__name__)
# Constants from Keller Jordan's Muon
MUON_EPS = 1e-7  # numerical stability floor for the norm divide in Newton-Schulz
DEFAULT_NS_STEPS = 5  # default number of Newton-Schulz iterations

# Coefficient presets (a, b, c) for the quintic Newton-Schulz iteration
#   X <- a*X + b*(X X^T) X + c*(X X^T)^2 X
# When a preset has fewer triples than steps, the last triple is repeated (see zeropower_via_newtonschulz).
_COEFFICIENTS = {
    "original": [
        # Keller Jordan's Muon https://kellerjordan.github.io/posts/muon/
        (3.4445, -4.7750, 2.0315),
    ],
    "quintic": [
        # https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients
        # From https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ],
    "polar_express": [
        # Polar Express https://arxiv.org/abs/2505.16932
        # From https://github.com/NoahAmsel/PolarExpress/tree/main with safety 1e-2
        (8.237312490495555, -23.157747414558198, 16.680568411445915),
        (4.082441999064835, -2.893047735332586, 0.5252849256975648),
        (3.9263479922546582, -2.8547468034765298, 0.5318022422894988),
        (3.2982187133085143, -2.424541981026706, 0.48632008358844075),
        (2.2970369434552573, -1.63662558125903, 0.4002628455953627),
        (1.8763805351440397, -1.2347896577722228, 0.35891887501668385),
        (1.8564423485617974, -1.2132449880935525, 0.3568003487825883),
        (1.8749994008682747, -1.2499988017229169, 0.3749994008546422),
    ],
    "polar_express_safer": [
        # from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
        # w/ safety 2e-2
        (8.156554524902461, -22.48329292557795, 15.878769915207462),
        (4.0429299351667245, -2.808917465908704, 0.5000178451051299),
        (3.8916678022926563, -2.7724841532176825, 0.5060648178503389),
        (3.285753657755658, -2.3681294933425394, 0.46449024233003117),
        (2.3005307116270983, -1.6111665557258408, 0.3833374427545273),
        (1.8631210546382593, -1.2042160621002727, 0.3421879560523383),
        (1.8382572152247512, -1.1779263289537742, 0.3396513038637379),
        (1.8749999923301852, -1.2499999836060613, 0.374999991275876),
    ],
}

# A coefficient spec: a preset name, a single (a, b, c) triple, or a list of triples (one per step).
NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]]
def zeropower_via_newtonschulz(
        G: torch.Tensor,
        steps: int,
        coefficients: List[Tuple[float, float, float]],
        eps: float = MUON_EPS,
        safety_factor: float = 1.0,
        dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient.

    Supports batched operation over leading dimensions.

    See
    - https://github.com/KellerJordan/Muon/blob/master/muon.py
    - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
    - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py

    Args:
        G: Input gradient tensor of shape (m, n) or (batch, m, n)
        steps: Number of Newton-Schulz iterations
        coefficients: Coefficients (a, b, c) for the iteration
        eps: Numerical stability epsilon for norm
        safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants)
        dtype: Computation dtype

    Returns:
        Orthogonalized tensor of same shape as G
    """
    assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D. Flatten batch dims first."
    num_cs = len(coefficients)
    assert num_cs >= 1 and len(coefficients[0]) == 3
    # match coefficients with # of steps, truncate or repeat last
    coeff_sequence = coefficients[:steps] if steps <= num_cs else \
        coefficients + [coefficients[-1]] * (steps - num_cs)
    X = G.to(dtype=dtype, copy=True)
    # Transpose if needed (operate on dimension with fewer elements)
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT
    # Normalize so spectral norm is at most 1 (the computed 2-norm over both dims upper-bounds it)
    X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min(eps))
    # Batched vs unbatched fused MM
    mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
    # Pre-allocate work buffers; every iteration writes into these in place
    X = X.contiguous()
    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
    B = torch.empty_like(A)
    C = torch.empty_like(X)
    # Perform Newton-Schulz iterations: X <- a*X + (b*A + c*A@A) @ X, where A = X @ X.mT
    for a, b, c in coeff_sequence:
        mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A)  # A = X @ X.mT (beta=0 ignores uninitialized A input)
        mm_fn(A, A, A, beta=b, alpha=c, out=B)  # B = b * A + c * A @ A
        mm_fn(X, B, X, beta=a, alpha=1.0, out=C)  # C = a * X + B @ X
        X, C = C, X  # swap refs to avoid copy
    if transposed:
        X = X.mT
    return X
def get_lr_scale(
        param_shape: torch.Size,
        adjust_lr_fn: str = "match_rms_adamw"
) -> float:
    """Compute a learning-rate scale factor from a parameter's shape.

    The last two dims are treated as (out, in); shapes with fewer than two
    dims fall back to (1., 1.), giving a neutral or constant scale.

    Args:
        param_shape: Shape of the (possibly reshaped) parameter/update tensor.
        adjust_lr_fn: One of "original", "match_rms_adamw", "rms_to_rms".

    Returns:
        Multiplicative scale to apply to the learning rate.

    Raises:
        ValueError: If ``adjust_lr_fn`` is not a recognized scaling function.
    """
    out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)
    if adjust_lr_fn == "original":
        # Original Muon impl (https://kellerjordan.github.io/posts/muon/)
        return max(1, out_chs / in_chs) ** 0.5
    elif adjust_lr_fn == "match_rms_adamw":
        # Kimi (https://arxiv.org/abs/2502.16982)
        return 0.2 * max(out_chs, in_chs) ** 0.5
    elif adjust_lr_fn == "rms_to_rms":
        # Scion (https://arxiv.org/abs/2502.07529, https://github.com/LIONS-EPFL/scion)
        # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
        return (out_chs / in_chs) ** 0.5
    else:
        # Raise instead of `assert False`: asserts are stripped under `python -O`,
        # which would make an invalid name silently return None.
        raise ValueError(f'Invalid scaling function "{adjust_lr_fn}"')
def _is_suitable_for_muon(
param: torch.Tensor,
min_dim_size: int = 4,
max_aspect_ratio: float = 128.,
return_reason: bool = False,
) -> Union[bool, Tuple[bool, str]]:
"""Check if a parameter is suitable for Muon optimization.
Args:
param: Parameter tensor
min_dim_size: Minimum size for non-unit dimensions
max_aspect_ratio: Maximum allowed aspect ratio
return_reason: If True, return (bool, reason_string), else just bool (faster)
Returns:
If return_reason=False: bool indicating suitability
If return_reason=True: Tuple of (is_suitable, reason_string)
Examples:
(64, 128) -> True (or (True, "ok") if return_reason=True)
(96, 3, 4, 4) -> True - will be flattened to (96, 48)
(4, 2048) -> False - extreme aspect ratio
(64,) -> False - insufficient dims
(1, 196, 768) -> False - leading unit dims
NOTE: these rules were created to balance complexity with covering common timm model cases
Please let me know if there are non-optimal cases that you run into.
"""
s = param.shape
# Must have at least 2 non-unit dimensions
if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2:
return (False, "insufficient_dims") if return_reason else False
# Unit dimension in first two positions indicates:
# - Position embeddings (1, seq, dim)
# - Depthwise convs (out, 1, h, w)
# - Other degenerate cases possibly not caught by first rule
if s[0] == 1 or s[1] == 1:
return (False, "leading_unit_dims") if return_reason else False
if param.ndim >= 3:
# For 3D+ tensors, check what dimensions will be AFTER flattening
# since that's what gets passed to Newton-Schulz iteration
# Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod)
out_ch = s[0]
in_ch_with_spatial = 1
for d in s[1:]:
in_ch_with_spatial *= d
check_dims = (out_ch, in_ch_with_spatial)
else:
# For 2D tensors, check as-is
check_dims = s
# Both dims should be >= minimum size
min_size = min(check_dims)
if min_size < min_dim_size:
if return_reason:
return False, f"min_dim_too_small:{min_size}"
return False
# Aspect ratio shouldn't be too extreme
max_size = max(check_dims)
aspect_ratio = max_size / min_size
if aspect_ratio > max_aspect_ratio:
if return_reason:
return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}"
return False
return (True, "ok") if return_reason else True
def reshape_for_muon(
        tensor: torch.Tensor,
        mode: str = "flatten",
) -> Tuple[torch.Tensor, torch.Size]:
    """Reshape a high-dimensional tensor into matrix form for Muon processing.

    Args:
        tensor: Input tensor of shape (out, in, *spatial).
        mode: Spatial-dimension handling:
            - "flatten": collapse spatial dims into the input dim -> (out, in*H*W)
            - "batched": per-position matrices -> (spatial_prod, out, in)

    Returns:
        Tuple of (reshaped tensor, original shape for later restoration).

    Raises:
        ValueError: If the tensor has fewer than 2 dims or ``mode`` is unknown.
    """
    shape = tensor.shape
    if tensor.ndim == 2:
        # Already a matrix - pass through untouched
        return tensor, shape
    if tensor.ndim < 2:
        raise ValueError(f"Tensor must have at least 2 dimensions, got {tensor.ndim}")
    if mode == "flatten":
        # (out, in, *spatial) -> (out, in * spatial_prod)
        return tensor.reshape(shape[0], -1), shape
    if mode == "batched":
        # (out, in, *spatial) -> (spatial_prod, out, in) so Newton-Schulz
        # batches over the leading spatial dimension
        stacked = tensor.reshape(shape[0], shape[1], -1).permute(2, 0, 1)
        return stacked, shape
    raise ValueError(f"Unknown mode: {mode}")
def muon(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    *,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_steps: int,
    ns_coefficients: NSCoeff,
    eps: float,
    safety_factor: float,
    adjust_lr_fn: Optional[str],
    conv_mode: str,
    normalize_spatial: bool,
) -> None:
    """Functional API that performs Muon algorithm computation."""
    # Mirrors the torch.optim functional-API shape; currently there is a
    # single (non-foreach) implementation to dispatch to.
    hyperparams = dict(
        lr=lr,
        weight_decay=weight_decay,
        momentum=momentum,
        nesterov=nesterov,
        ns_steps=ns_steps,
        ns_coefficients=ns_coefficients,
        eps=eps,
        safety_factor=safety_factor,
        adjust_lr_fn=adjust_lr_fn,
        conv_mode=conv_mode,
        normalize_spatial=normalize_spatial,
    )
    _single_tensor_muon(params, grads, momentum_bufs, **hyperparams)
def _single_tensor_muon(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    *,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_steps: int,
    ns_coefficients: NSCoeff,
    eps: float,
    safety_factor: float,
    adjust_lr_fn: Optional[str],
    conv_mode: str,
    normalize_spatial: bool,
) -> None:
    """Single tensor Muon update.

    For each parameter, applies in order: decoupled weight decay, momentum
    accumulation, Newton-Schulz orthogonalization of the update direction,
    shape-based LR scaling, and the final parameter step.
    """
    # Resolve a preset name / raw triple(s) into a validated list of (a, b, c)
    ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)
    for i, param in enumerate(params):
        grad = grads[i]
        momentum_buf = momentum_bufs[i]
        # Apply weight decay (decoupled, applied directly to the weights before the step)
        param.mul_(1 - lr * weight_decay)
        # Update momentum buffer: buf = momentum * buf + (1 - momentum) * grad
        momentum_buf.lerp_(grad, 1. - momentum)
        # NOTE(review): the nesterov branch mutates `grad` (i.e. p.grad) in place via
        # lerp_ — this mirrors Keller Jordan's reference implementation, where the
        # grad buffer is reused as scratch each step.
        update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone()
        # Reshape for processing (handle 3D+ tensors like conv weights)
        if update.ndim >= 3:
            update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode)
        else:
            update_reshaped = update
            original_shape = update.shape
        # Apply Newton-Schulz orthogonalization
        update_ortho = zeropower_via_newtonschulz(
            update_reshaped,
            ns_steps,
            ns_coefficients,
            eps=eps,
            safety_factor=safety_factor,
            #dtype=torch.bfloat16, # wire to arg?
        )
        # Adjust learning rate based on parameter shape (last two dims are (out, in),
        # so this also works for the (spatial, out, in) batched layout)
        scale = get_lr_scale(update_ortho.shape, adjust_lr_fn)
        # Apply spatial normalization and permute back if in batched mode
        if conv_mode == "batched" and update_ortho.ndim >= 3:
            if normalize_spatial:
                # Scale by 1/sqrt(spatial positions) to keep batched update magnitude in check
                scale *= update_ortho.shape[0] ** -0.5
            # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod)
            update_ortho = update_ortho.permute(1, 2, 0)
        # Reshape back to original shape
        update_ortho = update_ortho.reshape(original_shape)
        # Apply update
        param.add_(update_ortho, alpha=-lr * scale)
class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz

    Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
    parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
    """

    def __init__(
            self,
            params: ParamsT,
            lr: float = 0.02,
            weight_decay: float = 0,
            momentum: float = 0.95,
            nesterov: bool = False,
            ns_steps: int = DEFAULT_NS_STEPS,
            ns_coefficients: NSCoeff = "quintic",
            eps: float = MUON_EPS,
            safety_factor: float = 1.0,
            adjust_lr_fn: Optional[str] = "match_rms_adamw",
            conv_mode: str = "flatten",
            normalize_spatial: bool = True,
            adamw_lr: Optional[float] = None,
            betas: Tuple[float, float] = (0.9, 0.95),
            verbose: bool = False,
    ):
        """ Create Muon optimizer.

        Args:
            params: Iterable of parameters or dicts defining parameter groups
            lr: Learning rate (default: 0.02 for Muon parameters)
            weight_decay: Weight decay coefficient
            momentum: Momentum factor for Muon
            nesterov: Whether to use Nesterov momentum
            ns_steps: Number of Newton-Schulz iterations
            ns_coefficients: Coefficients for NS iteration (preset name, triple, or list of triples)
            eps: Numerical stability epsilon (shared by NS normalization and the AdamW fallback)
            safety_factor: Multiplicative safety factor for NS norm
            adjust_lr_fn: LR adjustment function - "original", "match_rms_adamw", or "rms_to_rms"
            conv_mode: How to handle convolutions - "flatten" or "batched"
            normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode
            adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified
            betas: AdamW beta coefficients
            verbose: Log parameter routing decisions (Muon vs AdamW)

        Example:
            ```python
            # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D
            optimizer = Muon(model.parameters(), lr=0.02)
            # Manual control over parameter groups
            optimizer = Muon([
                {'params': weight_matrices, 'lr': 0.02},
                {'params': biases, 'use_fallback': True, 'lr': 3e-4},  # use AdamW if use_fallback=True
            ])
            ```
        """
        # Validate hyper-parameters up front so bad configs fail at construction time
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if conv_mode not in ["flatten", "batched"]:
            raise ValueError(f"Invalid conv_mode: {conv_mode}")
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            ns_coefficients=ns_coefficients,
            eps=eps,
            safety_factor=safety_factor,
            adjust_lr_fn=adjust_lr_fn,
            conv_mode=conv_mode,
            normalize_spatial=normalize_spatial,
            # Fallback optimizer gets its own LR, defaulting to the Muon LR
            adamw_lr=adamw_lr if adamw_lr is not None else lr,
            betas=betas,
            verbose=verbose,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # NOTE: verbose is read from defaults, not per-group options
        verbose = self.defaults.get("verbose", False)
        # Tracking for logging (populated on first encounter of each param)
        muon_count = 0
        adamw_count = 0
        routing_reasons = {} if verbose else None

        for group in self.param_groups:
            # Separate params into Muon and AdamW groups
            muon_params = []
            muon_grads = []
            muon_momentum_bufs = []
            adamw_params = []
            adamw_grads = []
            adamw_exp_avgs = []
            adamw_exp_avg_sqs = []
            adamw_state_steps = []

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("Muon does not support sparse gradients")

                state = self.state[p]
                # Determine routing on first encounter (cache in state)
                if "use_muon" not in state:
                    # Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility)
                    reason = None
                    if group.get("use_fallback", False):
                        # use_fallback=True means use AdamW (use_muon=False)
                        state["use_muon"] = False
                        if verbose:
                            reason = "use_fallback_flag"
                    elif "use_muon" in group:
                        # Explicit use_muon flag for compatibility with other Muon implementations
                        state["use_muon"] = group["use_muon"]
                        if verbose:
                            reason = "use_muon_flag"
                    else:
                        # Check shape suitability; only compute the reason string when logging
                        if verbose:
                            suitable, reason = _is_suitable_for_muon(p, return_reason=True)
                        else:
                            suitable = _is_suitable_for_muon(p, return_reason=False)
                        state["use_muon"] = suitable
                    # Track routing decision for logging
                    if routing_reasons is not None and reason is not None:
                        shape_str = "x".join(str(s) for s in p.shape)
                        if shape_str not in routing_reasons:
                            routing_reasons[shape_str] = []
                        routing_reasons[shape_str].append(reason)

                # Use cached routing decision
                use_muon = state["use_muon"]
                if use_muon:
                    # Collect Muon params
                    muon_params.append(p)
                    muon_grads.append(p.grad)
                    muon_count += 1
                    # State initialization for Muon
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    muon_momentum_bufs.append(state["momentum_buffer"])
                else:
                    # Collect AdamW/NAdamW params
                    adamw_params.append(p)
                    adamw_grads.append(p.grad)
                    adamw_count += 1
                    # State initialization for AdamW
                    if "step" not in state:
                        state["step"] = torch.tensor(0.)
                        state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    adamw_exp_avgs.append(state["exp_avg"])
                    adamw_exp_avg_sqs.append(state["exp_avg_sq"])
                    adamw_state_steps.append(state["step"])

            # Apply Muon updates
            if muon_params:
                muon(
                    muon_params,
                    muon_grads,
                    muon_momentum_bufs,
                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    momentum=group["momentum"],
                    nesterov=group["nesterov"],
                    ns_steps=group["ns_steps"],
                    ns_coefficients=group["ns_coefficients"],
                    eps=group["eps"],
                    safety_factor=group["safety_factor"],
                    adjust_lr_fn=group["adjust_lr_fn"],
                    conv_mode=group["conv_mode"],
                    normalize_spatial=group["normalize_spatial"],
                )

            # Apply AdamW updates
            if adamw_params:
                beta1, beta2 = group["betas"]
                if group["nesterov"]:
                    # use nadamw for fallback optimizer if nesterov is enabled
                    nadamw(
                        adamw_params,
                        adamw_grads,
                        adamw_exp_avgs,
                        adamw_exp_avg_sqs,
                        adamw_state_steps,
                        foreach=None,
                        beta1=beta1,
                        beta2=beta2,
                        lr=group["adamw_lr"],
                        weight_decay=group["weight_decay"],
                        eps=group["eps"],
                        caution=False,
                        maximize=False,
                        capturable=False,
                        max_lr=None,
                    )
                else:
                    adamw(
                        adamw_params,
                        adamw_grads,
                        adamw_exp_avgs,
                        adamw_exp_avg_sqs,
                        [],  # max_exp_avg_sqs (not using amsgrad)
                        adamw_state_steps,
                        foreach=None,
                        amsgrad=False,
                        beta1=beta1,
                        beta2=beta2,
                        lr=group["adamw_lr"],
                        weight_decay=group["weight_decay"],
                        eps=group["eps"],
                        caution=False,
                        maximize=False,
                        capturable=False,
                        max_lr=None,
                    )

        # Log routing summary when we have new routing decisions
        if routing_reasons and len(routing_reasons) > 0:
            # Concise summary
            _logger.info(f"Muon parameter routing: {muon_count} Muon, {adamw_count} AdamW")
            # Group by reason for detailed breakdown
            reason_groups = {}
            for shape_str, reasons in sorted(routing_reasons.items()):
                for reason in reasons:
                    if reason not in reason_groups:
                        reason_groups[reason] = []
                    reason_groups[reason].append(shape_str)
            # Log summary counts per reason
            reason_summary = []
            for reason, shapes in sorted(reason_groups.items()):
                reason_summary.append(f"{reason}={len(shapes)}")
            _logger.info(f"  Breakdown: {', '.join(reason_summary)}")
            # Detailed breakdown at INFO level
            if _logger.isEnabledFor(logging.INFO):
                for reason, shapes in sorted(reason_groups.items()):
                    optimizer_name = "Muon" if reason == "ok" else "AdamW"
                    _logger.info(f"  {reason} -> {optimizer_name}:")
                    for shape in shapes[:10]:
                        _logger.info(f"    {shape}")
                    if len(shapes) > 10:
                        _logger.info(f"    ... and {len(shapes) - 10} more")

        return loss
def resolve_ns_coefficients(
        value: Union[str, Sequence[float], Sequence[Sequence[float]]],
        presets: Mapping[str, Sequence[Sequence[float]]]
) -> List[Tuple[float, float, float]]:
    """Normalize a Newton-Schulz coefficient spec into a list of float triples.

    Args:
        value: A preset name (looked up in ``presets``), a single (a, b, c)
            3-sequence of real numbers, or a sequence of such 3-sequences.
        presets: Mapping of preset name -> sequence of coefficient triples.

    Returns:
        List of validated ``(a, b, c)`` float tuples.

    Raises:
        ValueError: Unknown preset, malformed triple, or empty coefficient list.
        TypeError: ``value`` (or an item of it) is not a sequence.
    """
    # Named helper functions instead of lambda assignments (PEP 8 E731)
    def _is_seq(x) -> bool:
        # sequences, excluding str/bytes which iterate as chars/ints
        return isinstance(x, Sequence) and not isinstance(x, (str, bytes))

    def _is_real(x) -> bool:
        # real numbers, excluding bool (bool subclasses int)
        return isinstance(x, numbers.Real) and not isinstance(x, bool)

    def _as_coeff(x: Sequence[float]) -> Tuple[float, float, float]:
        if not _is_seq(x) or len(x) != 3 or not all(_is_real(v) for v in x):
            raise ValueError(f"Coefficient must be length-3 of real numbers, got: {x!r}")
        a, b, c = x  # type: ignore[misc]
        return float(a), float(b), float(c)

    if isinstance(value, str):
        if value not in presets:
            valid = ", ".join(sorted(presets.keys()))
            raise ValueError(f"Unknown coefficients preset '{value}'. Valid options: {valid}")
        seq = presets[value]
        if not _is_seq(seq) or len(seq) == 0:
            raise ValueError(f"Preset '{value}' is empty or invalid")
        return [_as_coeff(item) for item in seq]  # validate & cast

    if not _is_seq(value):
        raise TypeError(
            "Coefficients must be a preset name (str), a 3-sequence (a,b,c), "
            "or a sequence of 3-sequences."
        )

    # Decide single triple vs list-of-triples by structure
    if len(value) == 3 and all(_is_real(v) for v in value):  # type: ignore[index]
        return [_as_coeff(value)]  # single triple -> wrap

    # Otherwise treat as list/tuple of triples
    out = []
    for i, item in enumerate(value):  # type: ignore[assignment]
        if not _is_seq(item):
            raise TypeError(f"Item {i} is not a sequence: {item!r}")
        out.append(_as_coeff(item))
    if not out:
        raise ValueError("Coefficient list cannot be empty")
    return out