""" Muon Optimizer Improved Muon optimizer implementation with flexible handling of high-dimensional tensors. Combines PyTorch-style structure with options for: - Batched spatial processing for convolutions in addition to flatten - Optional spatial normalization - Selectable coefficient presets - Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) and optional fallback via param groups Based on implementation by Keller Jordan, see - https://github.com/KellerJordan/Muon/blob/master/muon.py - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py Hacked together by Ross Wightman """ import logging import numbers from typing import List, Mapping, Optional, Sequence, Tuple, Union import torch from ._types import ParamsT from .adamw import adamw from .nadamw import nadamw _logger = logging.getLogger(__name__) # Constants from Keller Jordan's Muon MUON_EPS = 1e-7 DEFAULT_NS_STEPS = 5 _COEFFICIENTS = { "original": [ # Keller Jordan's Muon https://kellerjordan.github.io/posts/muon/ (3.4445, -4.7750, 2.0315), ], "quintic": [ # https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients # From https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44 (4.0848, -6.8946, 2.9270), (3.9505, -6.3029, 2.6377), (3.7418, -5.5913, 2.3037), (2.8769, -3.1427, 1.2046), (2.8366, -3.0525, 1.2012), ], "polar_express": [ # Polar Express https://arxiv.org/abs/2505.16932 # From https://github.com/NoahAmsel/PolarExpress/tree/main with safety 1e-2 (8.237312490495555, -23.157747414558198, 16.680568411445915), (4.082441999064835, -2.893047735332586, 0.5252849256975648), (3.9263479922546582, -2.8547468034765298, 0.5318022422894988), (3.2982187133085143, -2.424541981026706, 0.48632008358844075), (2.2970369434552573, -1.63662558125903, 0.4002628455953627), (1.8763805351440397, -1.2347896577722228, 0.35891887501668385), (1.8564423485617974, -1.2132449880935525, 0.3568003487825883), (1.8749994008682747, -1.2499988017229169, 0.3749994008546422), ], "polar_express_safer": [ # from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py # w/ safety 2e-2 (8.156554524902461, -22.48329292557795, 15.878769915207462), (4.0429299351667245, -2.808917465908704, 0.5000178451051299), (3.8916678022926563, -2.7724841532176825, 0.5060648178503389), (3.285753657755658, -2.3681294933425394, 0.46449024233003117), (2.3005307116270983, -1.6111665557258408, 0.3833374427545273), (1.8631210546382593, -1.2042160621002727, 0.3421879560523383), (1.8382572152247512, -1.1779263289537742, 0.3396513038637379), (1.8749999923301852, -1.2499999836060613, 0.374999991275876), ], } NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]] def zeropower_via_newtonschulz( G: torch.Tensor, steps: int, coefficients: List[Tuple[float, float, float]], eps: float = MUON_EPS, safety_factor: float = 1.0, dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient. Supports batched operation over leading dimensions. See - https://github.com/KellerJordan/Muon/blob/master/muon.py - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py Args: G: Input gradient tensor of shape (m, n) or (batch, m, n) steps: Number of Newton-Schulz iterations coefficients: Coefficients (a, b, c) for the iteration eps: Numerical stability epsilon for norm safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants) dtype: Computation dtype Returns: Orthogonalized tensor of same shape as G """ assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D. Flatten batch dims first." num_cs = len(coefficients) assert num_cs >= 1 and len(coefficients[0]) == 3 # match coefficients with # of steps, truncate or repeat last coeff_sequence = coefficients[:steps] if steps <= num_cs else \ coefficients + [coefficients[-1]] * (steps - num_cs) X = G.to(dtype=dtype, copy=True) # Transpose if needed (operate on dimension with fewer elements) transposed = X.size(-2) > X.size(-1) if transposed: X = X.mT # Normalize spectral norm to at most 1 X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min(eps)) # Batched vs unbatched fused MM mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm # Pre-allocate X = X.contiguous() A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) B = torch.empty_like(A) C = torch.empty_like(X) # Perform Newton-Schulz iterations for a, b, c in coeff_sequence: mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X X, C = C, X # swap refs to avoid copy if transposed: X = X.mT return X def get_lr_scale( param_shape: torch.Size, adjust_lr_fn: str = "match_rms_adamw" ) -> float: """Adjust learning rate based on parameter shape.""" out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.) if adjust_lr_fn == "original": # Original Muon impl (https://kellerjordan.github.io/posts/muon/) return max(1, out_chs / in_chs) ** 0.5 elif adjust_lr_fn == "match_rms_adamw": # Kimi (https://arxiv.org/abs/2502.16982) return 0.2 * max(out_chs, in_chs) ** 0.5 elif adjust_lr_fn == "rms_to_rms": # Scion (https://arxiv.org/abs/2502.07529, https://github.com/LIONS-EPFL/scion) # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon) return (out_chs / in_chs) ** 0.5 else: assert False, f'Invalid scaling function "{adjust_lr_fn}"' def _is_suitable_for_muon( param: torch.Tensor, min_dim_size: int = 4, max_aspect_ratio: float = 128., return_reason: bool = False, ) -> Union[bool, Tuple[bool, str]]: """Check if a parameter is suitable for Muon optimization. Args: param: Parameter tensor min_dim_size: Minimum size for non-unit dimensions max_aspect_ratio: Maximum allowed aspect ratio return_reason: If True, return (bool, reason_string), else just bool (faster) Returns: If return_reason=False: bool indicating suitability If return_reason=True: Tuple of (is_suitable, reason_string) Examples: (64, 128) -> True (or (True, "ok") if return_reason=True) (96, 3, 4, 4) -> True - will be flattened to (96, 48) (4, 2048) -> False - extreme aspect ratio (64,) -> False - insufficient dims (1, 196, 768) -> False - leading unit dims NOTE: these rules were created to balance complexity with covering common timm model cases Please let me know if there are non-optimal cases that you run into. """ s = param.shape # Must have at least 2 non-unit dimensions if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2: return (False, "insufficient_dims") if return_reason else False # Unit dimension in first two positions indicates: # - Position embeddings (1, seq, dim) # - Depthwise convs (out, 1, h, w) # - Other degenerate cases possibly not caught by first rule if s[0] == 1 or s[1] == 1: return (False, "leading_unit_dims") if return_reason else False if param.ndim >= 3: # For 3D+ tensors, check what dimensions will be AFTER flattening # since that's what gets passed to Newton-Schulz iteration # Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod) out_ch = s[0] in_ch_with_spatial = 1 for d in s[1:]: in_ch_with_spatial *= d check_dims = (out_ch, in_ch_with_spatial) else: # For 2D tensors, check as-is check_dims = s # Both dims should be >= minimum size min_size = min(check_dims) if min_size < min_dim_size: if return_reason: return False, f"min_dim_too_small:{min_size}" return False # Aspect ratio shouldn't be too extreme max_size = max(check_dims) aspect_ratio = max_size / min_size if aspect_ratio > max_aspect_ratio: if return_reason: return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}" return False return (True, "ok") if return_reason else True def reshape_for_muon( tensor: torch.Tensor, mode: str = "flatten", ) -> Tuple[torch.Tensor, torch.Size]: """Reshape high-dimensional tensor for Muon processing. Args: tensor: Input tensor of shape (out, in, *spatial) mode: How to handle spatial dimensions - "flatten": Flatten spatial into output dimension (out, in*H*W) - "batched": Batch over spatial positions (spatial_prod, out, in) for per-position orthogonalization Returns: Reshaped tensor and original shape for restoration """ original_shape = tensor.shape if tensor.ndim == 2: return tensor, original_shape if tensor.ndim < 2: raise ValueError(f"Tensor must have at least 2 dimensions, got {tensor.ndim}") out_ch, in_ch = tensor.shape[:2] if mode == "flatten": # Flatten: (out, in, *spatial) -> (out, in * spatial_prod) return tensor.reshape(out_ch, -1), original_shape elif mode == "batched": # Batched: (out, in, *spatial) -> (spatial_prod, out, in) # Move spatial dimension to front so zeropower_via_newtonschulz batches over it reshaped = tensor.reshape(out_ch, in_ch, -1) # (out, in, spatial_prod) reshaped = reshaped.permute(2, 0, 1) # (spatial_prod, out, in) return reshaped, original_shape else: raise ValueError(f"Unknown mode: {mode}") def muon( params: List[torch.Tensor], grads: List[torch.Tensor], momentum_bufs: List[torch.Tensor], *, lr: float, weight_decay: float, momentum: float, nesterov: bool, ns_steps: int, ns_coefficients: NSCoeff, eps: float, safety_factor: float, adjust_lr_fn: Optional[str], conv_mode: str, normalize_spatial: bool, ) -> None: """Functional API that performs Muon algorithm computation.""" _single_tensor_muon( params, grads, momentum_bufs, lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, ns_coefficients=ns_coefficients, eps=eps, safety_factor=safety_factor, adjust_lr_fn=adjust_lr_fn, conv_mode=conv_mode, normalize_spatial=normalize_spatial, ) def _single_tensor_muon( params: List[torch.Tensor], grads: List[torch.Tensor], momentum_bufs: List[torch.Tensor], *, lr: float, weight_decay: float, momentum: float, nesterov: bool, ns_steps: int, ns_coefficients: NSCoeff, eps: float, safety_factor: float, adjust_lr_fn: Optional[str], conv_mode: str, normalize_spatial: bool, ) -> None: """Single tensor Muon update.""" ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS) for i, param in enumerate(params): grad = grads[i] momentum_buf = momentum_bufs[i] # Apply weight decay param.mul_(1 - lr * weight_decay) # Update momentum buffer momentum_buf.lerp_(grad, 1. - momentum) update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone() # Reshape for processing (handle 3D+ tensors like conv weights) if update.ndim >= 3: update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode) else: update_reshaped = update original_shape = update.shape # Apply Newton-Schulz orthogonalization update_ortho = zeropower_via_newtonschulz( update_reshaped, ns_steps, ns_coefficients, eps=eps, safety_factor=safety_factor, #dtype=torch.bfloat16, # wire to arg? ) # Adjust learning rate based on parameter shape scale = get_lr_scale(update_ortho.shape, adjust_lr_fn) # Apply spatial normalization and permute back if in batched mode if conv_mode == "batched" and update_ortho.ndim >= 3: if normalize_spatial: scale *= update_ortho.shape[0] ** -0.5 # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod) update_ortho = update_ortho.permute(1, 2, 0) # Reshape back to original shape update_ortho = update_ortho.reshape(original_shape) # Apply update param.add_(update_ortho, alpha=-lr * scale) class Muon(torch.optim.Optimizer): """Muon - MomentUm Orthogonalized by Newton-schulz Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility). """ def __init__( self, params: ParamsT, lr: float = 0.02, weight_decay: float = 0, momentum: float = 0.95, nesterov: bool = False, ns_steps: int = DEFAULT_NS_STEPS, ns_coefficients: NSCoeff = "quintic", eps: float = MUON_EPS, safety_factor: float = 1.0, adjust_lr_fn: Optional[str] = "match_rms_adamw", conv_mode: str = "flatten", normalize_spatial: bool = True, adamw_lr: Optional[float] = None, betas: Tuple[float, float] = (0.9, 0.95), verbose: bool = False, ): """ Create Muon optimizer. Args: params: Iterable of parameters or dicts defining parameter groups lr: Learning rate (default: 0.02 for Muon parameters) weight_decay: Weight decay coefficient momentum: Momentum factor for Muon nesterov: Whether to use Nesterov momentum ns_steps: Number of Newton-Schulz iterations ns_coefficients: Coefficients for NS iteration eps: Numerical stability epsilon safety_factor: Multiplicative safety factor for NS norm adjust_lr_fn: LR adjustment function - "original" or "match_rms_adamw" conv_mode: How to handle convolutions - "flatten" or "batched" normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified betas: AdamW beta coefficients verbose: Log parameter routing decisions (Muon vs AdamW) Example: ```python # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D optimizer = Muon(model.parameters(), lr=0.02) # Manual control over parameter groups optimizer = Muon([ {'params': weight_matrices, 'lr': 0.02}, {'params': biases, 'use_fallback': True, 'lr': 3e-4}, # use AdamW if use_fallback=True ]) ``` """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= momentum < 1.0: raise ValueError(f"Invalid momentum value: {momentum}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if conv_mode not in ["flatten", "batched"]: raise ValueError(f"Invalid conv_mode: {conv_mode}") defaults = dict( lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, ns_coefficients=ns_coefficients, eps=eps, safety_factor=safety_factor, adjust_lr_fn=adjust_lr_fn, conv_mode=conv_mode, normalize_spatial=normalize_spatial, adamw_lr=adamw_lr if adamw_lr is not None else lr, betas=betas, verbose=verbose, ) super().__init__(params, defaults) @torch.no_grad() def step(self, closure=None): """Performs a single optimization step.""" loss = None if closure is not None: with torch.enable_grad(): loss = closure() verbose = self.defaults.get("verbose", False) # Tracking for logging (populated on first encounter of each param) muon_count = 0 adamw_count = 0 routing_reasons = {} if verbose else None for group in self.param_groups: # Separate params into Muon and AdamW groups muon_params = [] muon_grads = [] muon_momentum_bufs = [] adamw_params = [] adamw_grads = [] adamw_exp_avgs = [] adamw_exp_avg_sqs = [] adamw_state_steps = [] for p in group["params"]: if p.grad is None: continue if p.grad.is_sparse: raise RuntimeError("Muon does not support sparse gradients") state = self.state[p] # Determine routing on first encounter (cache in state) if "use_muon" not in state: # Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility) reason = None if group.get("use_fallback", False): # use_fallback=True means use AdamW (use_muon=False) state["use_muon"] = False if verbose: reason = "use_fallback_flag" elif "use_muon" in group: # Explicit use_muon flag for compatibility with other Muon implementations state["use_muon"] = group["use_muon"] if verbose: reason = "use_muon_flag" else: # Check shape suitability if verbose: suitable, reason = _is_suitable_for_muon(p, return_reason=True) else: suitable = _is_suitable_for_muon(p, return_reason=False) state["use_muon"] = suitable # Track routing decision for logging if routing_reasons is not None and reason is not None: shape_str = "x".join(str(s) for s in p.shape) if shape_str not in routing_reasons: routing_reasons[shape_str] = [] routing_reasons[shape_str].append(reason) # Use cached routing decision use_muon = state["use_muon"] if use_muon: # Collect Muon params muon_params.append(p) muon_grads.append(p.grad) muon_count += 1 # State initialization for Muon if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) muon_momentum_bufs.append(state["momentum_buffer"]) else: # Collect AdamW/NAdamW params adamw_params.append(p) adamw_grads.append(p.grad) adamw_count += 1 # State initialization for AdamW if "step" not in state: state["step"] = torch.tensor(0.) state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) adamw_exp_avgs.append(state["exp_avg"]) adamw_exp_avg_sqs.append(state["exp_avg_sq"]) adamw_state_steps.append(state["step"]) # Apply Muon updates if muon_params: muon( muon_params, muon_grads, muon_momentum_bufs, lr=group["lr"], weight_decay=group["weight_decay"], momentum=group["momentum"], nesterov=group["nesterov"], ns_steps=group["ns_steps"], ns_coefficients=group["ns_coefficients"], eps=group["eps"], safety_factor=group["safety_factor"], adjust_lr_fn=group["adjust_lr_fn"], conv_mode=group["conv_mode"], normalize_spatial=group["normalize_spatial"], ) # Apply AdamW updates if adamw_params: beta1, beta2 = group["betas"] if group["nesterov"]: # use nadamw for fallback optimizer if nesterov is enabled nadamw( adamw_params, adamw_grads, adamw_exp_avgs, adamw_exp_avg_sqs, adamw_state_steps, foreach=None, beta1=beta1, beta2=beta2, lr=group["adamw_lr"], weight_decay=group["weight_decay"], eps=group["eps"], caution=False, maximize=False, capturable=False, max_lr=None, ) else: adamw( adamw_params, adamw_grads, adamw_exp_avgs, adamw_exp_avg_sqs, [], # max_exp_avg_sqs (not using amsgrad) adamw_state_steps, foreach=None, amsgrad=False, beta1=beta1, beta2=beta2, lr=group["adamw_lr"], weight_decay=group["weight_decay"], eps=group["eps"], caution=False, maximize=False, capturable=False, max_lr=None, ) # Log routing summary when we have new routing decisions if routing_reasons and len(routing_reasons) > 0: # Concise summary _logger.info(f"Muon parameter routing: {muon_count} Muon, {adamw_count} AdamW") # Group by reason for detailed breakdown reason_groups = {} for shape_str, reasons in sorted(routing_reasons.items()): for reason in reasons: if reason not in reason_groups: reason_groups[reason] = [] reason_groups[reason].append(shape_str) # Log summary counts per reason reason_summary = [] for reason, shapes in sorted(reason_groups.items()): reason_summary.append(f"{reason}={len(shapes)}") _logger.info(f" Breakdown: {', '.join(reason_summary)}") # Detailed breakdown at INFO level if _logger.isEnabledFor(logging.INFO): for reason, shapes in sorted(reason_groups.items()): optimizer_name = "Muon" if reason == "ok" else "AdamW" _logger.info(f" {reason} -> {optimizer_name}:") for shape in shapes[:10]: _logger.info(f" {shape}") if len(shapes) > 10: _logger.info(f" ... and {len(shapes) - 10} more") return loss def resolve_ns_coefficients( value: Union[str, Sequence[float], Sequence[Sequence[float]]], presets: Mapping[str, Sequence[Sequence[float]]] ) -> List[Tuple[float, float, float]]: # tiny helpers (kept inline for succinctness) is_seq = lambda x: isinstance(x, Sequence) and not isinstance(x, (str, bytes)) is_real = lambda x: isinstance(x, numbers.Real) and not isinstance(x, bool) def as_coeff(x: Sequence[float]) -> Tuple[float, float, float]: if not is_seq(x) or len(x) != 3 or not all(is_real(v) for v in x): raise ValueError(f"Coefficient must be length-3 of real numbers, got: {x!r}") a, b, c = x # type: ignore[misc] return float(a), float(b), float(c) if isinstance(value, str): if value not in presets: valid = ", ".join(sorted(presets.keys())) raise ValueError(f"Unknown coefficients preset '{value}'. Valid options: {valid}") seq = presets[value] if not is_seq(seq) or len(seq) == 0: raise ValueError(f"Preset '{value}' is empty or invalid") return [as_coeff(item) for item in seq] # validate & cast if not is_seq(value): raise TypeError( "Coefficients must be a preset name (str), a 3-sequence (a,b,c), " "or a sequence of 3-sequences." ) # Decide single triple vs list-of-triples by structure if len(value) == 3 and all(is_real(v) for v in value): # type: ignore[index] return [as_coeff(value)] # single triple -> wrap # Otherwise treat as list/tuple of triples out = [] for i, item in enumerate(value): # type: ignore[assignment] if not is_seq(item): raise TypeError(f"Item {i} is not a sequence: {item!r}") out.append(as_coeff(item)) if not out: raise ValueError("Coefficient list cannot be empty") return out