687 lines
26 KiB
Python
687 lines
26 KiB
Python
""" Muon Optimizer
|
|
|
|
Improved Muon optimizer implementation with flexible handling of high-dimensional tensors.
|
|
|
|
Combines PyTorch-style structure with options for:
|
|
- Batched spatial processing for convolutions in addition to flatten
|
|
- Optional spatial normalization
|
|
- Selectable coefficient presets
|
|
- Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) and optional fallback via param groups
|
|
|
|
Based on implementation by Keller Jordan, see
|
|
- https://github.com/KellerJordan/Muon/blob/master/muon.py
|
|
- https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
|
|
- https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py
|
|
- https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
|
|
|
|
Hacked together by Ross Wightman
|
|
"""
|
|
import logging
|
|
import numbers
|
|
from typing import List, Mapping, Optional, Sequence, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from ._types import ParamsT
|
|
from .adamw import adamw
|
|
from .nadamw import nadamw
|
|
|
|
_logger = logging.getLogger(__name__)

# Constants from Keller Jordan's Muon
MUON_EPS = 1e-7  # numerical floor used when normalizing by the matrix norm
DEFAULT_NS_STEPS = 5  # default number of Newton-Schulz iterations

# Named presets of Newton-Schulz iteration coefficients (a, b, c).
# If a preset supplies fewer tuples than the requested number of steps,
# the last tuple is repeated (see zeropower_via_newtonschulz).
_COEFFICIENTS = {
    "original": [
        # Keller Jordan's Muon https://kellerjordan.github.io/posts/muon/
        (3.4445, -4.7750, 2.0315),
    ],
    "quintic": [
        # https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients
        # From https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ],
    "polar_express": [
        # Polar Express https://arxiv.org/abs/2505.16932
        # From https://github.com/NoahAmsel/PolarExpress/tree/main with safety 1e-2
        (8.237312490495555, -23.157747414558198, 16.680568411445915),
        (4.082441999064835, -2.893047735332586, 0.5252849256975648),
        (3.9263479922546582, -2.8547468034765298, 0.5318022422894988),
        (3.2982187133085143, -2.424541981026706, 0.48632008358844075),
        (2.2970369434552573, -1.63662558125903, 0.4002628455953627),
        (1.8763805351440397, -1.2347896577722228, 0.35891887501668385),
        (1.8564423485617974, -1.2132449880935525, 0.3568003487825883),
        (1.8749994008682747, -1.2499988017229169, 0.3749994008546422),
    ],
    "polar_express_safer": [
        # from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
        # w/ safety 2e-2
        (8.156554524902461, -22.48329292557795, 15.878769915207462),
        (4.0429299351667245, -2.808917465908704, 0.5000178451051299),
        (3.8916678022926563, -2.7724841532176825, 0.5060648178503389),
        (3.285753657755658, -2.3681294933425394, 0.46449024233003117),
        (2.3005307116270983, -1.6111665557258408, 0.3833374427545273),
        (1.8631210546382593, -1.2042160621002727, 0.3421879560523383),
        (1.8382572152247512, -1.1779263289537742, 0.3396513038637379),
        (1.8749999923301852, -1.2499999836060613, 0.374999991275876),
    ],
}

# A coefficient spec accepted by the public API: a preset name, a single
# (a, b, c) triple, or an explicit list of triples (one per NS step).
NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]]
|
|
|
|
|
|
def zeropower_via_newtonschulz(
    G: torch.Tensor,
    steps: int,
    coefficients: List[Tuple[float, float, float]],
    eps: float = MUON_EPS,
    safety_factor: float = 1.0,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient.

    Supports batched operation over leading dimensions.

    See
    - https://github.com/KellerJordan/Muon/blob/master/muon.py
    - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
    - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py

    Args:
        G: Input gradient tensor of shape (m, n) or (batch, m, n)
        steps: Number of Newton-Schulz iterations
        coefficients: Coefficients (a, b, c) for the iteration
        eps: Numerical stability epsilon for norm
        safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants)
        dtype: Computation dtype

    Returns:
        Orthogonalized tensor of same shape as G
    """
    assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D. Flatten batch dims first."
    num_cs = len(coefficients)
    assert num_cs >= 1 and len(coefficients[0]) == 3
    # Match coefficient schedule to # of steps: truncate if we have more
    # tuples than steps, otherwise repeat the last tuple to fill.
    coeff_sequence = coefficients[:steps] if steps <= num_cs else \
        coefficients + [coefficients[-1]] * (steps - num_cs)

    # Work on a copy in the (typically lower precision) compute dtype.
    X = G.to(dtype=dtype, copy=True)

    # Transpose if needed (operate on dimension with fewer elements)
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT

    # Normalize so the spectral norm is at most 1: the 2-norm over both matrix
    # dims (Frobenius) upper-bounds the spectral norm; safety_factor shrinks a
    # bit further, clamp_min(eps) guards against division by ~0.
    X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min(eps))

    # Batched vs unbatched fused multiply-add matmul
    mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm

    # Pre-allocate work buffers so the iteration loop does no per-step allocation
    X = X.contiguous()
    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
    B = torch.empty_like(A)
    C = torch.empty_like(X)

    # Perform Newton-Schulz iterations
    for a, b, c in coeff_sequence:
        # beta=0.0 means the (uninitialized) additive input A is ignored here
        mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A)  # A = X @ X.mT
        mm_fn(A, A, A, beta=b, alpha=c, out=B)  # B = b * A + c * A @ A
        mm_fn(X, B, X, beta=a, alpha=1.0, out=C)  # C = a * X + B @ X
        X, C = C, X  # swap refs to avoid copy

    if transposed:
        X = X.mT

    return X
|
|
|
|
|
|
def get_lr_scale(
    param_shape: torch.Size,
    adjust_lr_fn: Optional[str] = "match_rms_adamw",
) -> float:
    """Compute a learning-rate scale factor from a parameter's shape.

    The last two dims are interpreted as (out_chs, in_chs); tensors with fewer
    than 2 dims are treated as (1, 1), i.e. effectively unscaled.

    Args:
        param_shape: Shape of the (possibly reshaped) update tensor.
        adjust_lr_fn: Scaling rule - "original", "match_rms_adamw", or
            "rms_to_rms". ``None`` falls back to the default "match_rms_adamw"
            (callers pass an ``Optional[str]`` through; previously ``None``
            triggered an assertion failure).

    Returns:
        Multiplicative scale to apply to the base learning rate.

    Raises:
        ValueError: If ``adjust_lr_fn`` is not a recognized rule name.
    """
    # Treat a missing selection as the default rule rather than failing.
    if adjust_lr_fn is None:
        adjust_lr_fn = "match_rms_adamw"

    out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)

    if adjust_lr_fn == "original":
        # Original Muon impl (https://kellerjordan.github.io/posts/muon/)
        return max(1, out_chs / in_chs) ** 0.5
    elif adjust_lr_fn == "match_rms_adamw":
        # Kimi (https://arxiv.org/abs/2502.16982)
        return 0.2 * max(out_chs, in_chs) ** 0.5
    elif adjust_lr_fn == "rms_to_rms":
        # Scion (https://arxiv.org/abs/2502.07529, https://github.com/LIONS-EPFL/scion)
        # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
        return (out_chs / in_chs) ** 0.5
    # Raise (not assert) so validation survives `python -O`.
    raise ValueError(f'Invalid scaling function "{adjust_lr_fn}"')
|
|
|
|
|
|
def _is_suitable_for_muon(
|
|
param: torch.Tensor,
|
|
min_dim_size: int = 4,
|
|
max_aspect_ratio: float = 128.,
|
|
return_reason: bool = False,
|
|
) -> Union[bool, Tuple[bool, str]]:
|
|
"""Check if a parameter is suitable for Muon optimization.
|
|
|
|
Args:
|
|
param: Parameter tensor
|
|
min_dim_size: Minimum size for non-unit dimensions
|
|
max_aspect_ratio: Maximum allowed aspect ratio
|
|
return_reason: If True, return (bool, reason_string), else just bool (faster)
|
|
|
|
Returns:
|
|
If return_reason=False: bool indicating suitability
|
|
If return_reason=True: Tuple of (is_suitable, reason_string)
|
|
|
|
Examples:
|
|
(64, 128) -> True (or (True, "ok") if return_reason=True)
|
|
(96, 3, 4, 4) -> True - will be flattened to (96, 48)
|
|
(4, 2048) -> False - extreme aspect ratio
|
|
(64,) -> False - insufficient dims
|
|
(1, 196, 768) -> False - leading unit dims
|
|
|
|
NOTE: these rules were created to balance complexity with covering common timm model cases
|
|
Please let me know if there are non-optimal cases that you run into.
|
|
"""
|
|
|
|
s = param.shape
|
|
# Must have at least 2 non-unit dimensions
|
|
if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2:
|
|
return (False, "insufficient_dims") if return_reason else False
|
|
|
|
# Unit dimension in first two positions indicates:
|
|
# - Position embeddings (1, seq, dim)
|
|
# - Depthwise convs (out, 1, h, w)
|
|
# - Other degenerate cases possibly not caught by first rule
|
|
if s[0] == 1 or s[1] == 1:
|
|
return (False, "leading_unit_dims") if return_reason else False
|
|
|
|
if param.ndim >= 3:
|
|
# For 3D+ tensors, check what dimensions will be AFTER flattening
|
|
# since that's what gets passed to Newton-Schulz iteration
|
|
# Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod)
|
|
out_ch = s[0]
|
|
in_ch_with_spatial = 1
|
|
for d in s[1:]:
|
|
in_ch_with_spatial *= d
|
|
check_dims = (out_ch, in_ch_with_spatial)
|
|
else:
|
|
# For 2D tensors, check as-is
|
|
check_dims = s
|
|
|
|
# Both dims should be >= minimum size
|
|
min_size = min(check_dims)
|
|
if min_size < min_dim_size:
|
|
if return_reason:
|
|
return False, f"min_dim_too_small:{min_size}"
|
|
return False
|
|
|
|
# Aspect ratio shouldn't be too extreme
|
|
max_size = max(check_dims)
|
|
aspect_ratio = max_size / min_size
|
|
if aspect_ratio > max_aspect_ratio:
|
|
if return_reason:
|
|
return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}"
|
|
return False
|
|
|
|
return (True, "ok") if return_reason else True
|
|
|
|
|
|
def reshape_for_muon(
    tensor: torch.Tensor,
    mode: str = "flatten",
) -> Tuple[torch.Tensor, torch.Size]:
    """Reshape a high-dimensional tensor into matrix form for Muon processing.

    Args:
        tensor: Input tensor of shape (out, in, *spatial)
        mode: How to handle spatial dimensions
            - "flatten": Flatten spatial into output dimension (out, in*H*W)
            - "batched": Batch over spatial positions (spatial_prod, out, in) for per-position orthogonalization

    Returns:
        Reshaped tensor and original shape for restoration
    """
    input_shape = tensor.shape

    # 2D is already matrix-shaped; pass through untouched (mode irrelevant).
    if tensor.ndim == 2:
        return tensor, input_shape
    if tensor.ndim < 2:
        raise ValueError(f"Tensor must have at least 2 dimensions, got {tensor.ndim}")

    out_ch, in_ch = input_shape[:2]
    if mode == "flatten":
        # (out, in, *spatial) -> (out, in * spatial_prod)
        return tensor.reshape(out_ch, -1), input_shape
    if mode == "batched":
        # (out, in, *spatial) -> (spatial_prod, out, in); the spatial axis is
        # moved to the front so zeropower_via_newtonschulz batches over it.
        stacked = tensor.reshape(out_ch, in_ch, -1).permute(2, 0, 1)
        return stacked, input_shape
    raise ValueError(f"Unknown mode: {mode}")
|
|
|
|
|
|
def muon(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    *,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_steps: int,
    ns_coefficients: NSCoeff,
    eps: float,
    safety_factor: float,
    adjust_lr_fn: Optional[str],
    conv_mode: str,
    normalize_spatial: bool,
) -> None:
    """Functional API that performs Muon algorithm computation.

    Thin dispatcher: forwards all hyperparameters to the single-tensor
    implementation (mirrors the torch.optim functional API pattern).
    """
    hyperparams = dict(
        lr=lr,
        weight_decay=weight_decay,
        momentum=momentum,
        nesterov=nesterov,
        ns_steps=ns_steps,
        ns_coefficients=ns_coefficients,
        eps=eps,
        safety_factor=safety_factor,
        adjust_lr_fn=adjust_lr_fn,
        conv_mode=conv_mode,
        normalize_spatial=normalize_spatial,
    )
    _single_tensor_muon(params, grads, momentum_bufs, **hyperparams)
|
|
|
|
|
|
def _single_tensor_muon(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    *,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_steps: int,
    ns_coefficients: NSCoeff,
    eps: float,
    safety_factor: float,
    adjust_lr_fn: Optional[str],
    conv_mode: str,
    normalize_spatial: bool,
) -> None:
    """Single tensor Muon update.

    For each parameter, in order: decoupled weight decay, momentum
    accumulation, reshape (for conv-style 3D+ weights), Newton-Schulz
    orthogonalization of the update, shape-based LR scaling, then the
    final in-place parameter step.
    """
    # Resolve preset name / raw sequence into a validated list of (a, b, c) triples
    ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)

    for i, param in enumerate(params):
        grad = grads[i]
        momentum_buf = momentum_bufs[i]

        # Decoupled weight decay applied directly to the parameter
        param.mul_(1 - lr * weight_decay)

        # Update momentum buffer: buf = momentum * buf + (1 - momentum) * grad
        momentum_buf.lerp_(grad, 1. - momentum)
        # NOTE: the nesterov branch mutates grads[i] (the gradient tensor) in
        # place via lerp_; this matches the reference Muon implementation.
        # The non-nesterov branch clones so the buffer itself is never aliased.
        update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone()

        # Reshape for processing (handle 3D+ tensors like conv weights)
        if update.ndim >= 3:
            update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode)
        else:
            update_reshaped = update
            original_shape = update.shape

        # Apply Newton-Schulz orthogonalization
        update_ortho = zeropower_via_newtonschulz(
            update_reshaped,
            ns_steps,
            ns_coefficients,
            eps=eps,
            safety_factor=safety_factor,
            #dtype=torch.bfloat16, # wire to arg?
        )

        # Adjust learning rate based on parameter shape (last two dims, so this
        # works for both 2D and batched 3D updates)
        scale = get_lr_scale(update_ortho.shape, adjust_lr_fn)

        # Apply spatial normalization and permute back if in batched mode
        if conv_mode == "batched" and update_ortho.ndim >= 3:
            if normalize_spatial:
                # Divide by sqrt(spatial positions) so the summed per-position
                # updates keep a comparable magnitude
                scale *= update_ortho.shape[0] ** -0.5
            # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod)
            update_ortho = update_ortho.permute(1, 2, 0)

        # Reshape back to original shape
        update_ortho = update_ortho.reshape(original_shape)

        # Apply update
        param.add_(update_ortho, alpha=-lr * scale)
|
|
|
|
|
|
class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz

    Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
    parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 0.02,
        weight_decay: float = 0,
        momentum: float = 0.95,
        nesterov: bool = False,
        ns_steps: int = DEFAULT_NS_STEPS,
        ns_coefficients: NSCoeff = "quintic",
        eps: float = MUON_EPS,
        safety_factor: float = 1.0,
        adjust_lr_fn: Optional[str] = "match_rms_adamw",
        conv_mode: str = "flatten",
        normalize_spatial: bool = True,
        adamw_lr: Optional[float] = None,
        betas: Tuple[float, float] = (0.9, 0.95),
        verbose: bool = False,
    ):
        """ Create Muon optimizer.
        Args:
            params: Iterable of parameters or dicts defining parameter groups
            lr: Learning rate (default: 0.02 for Muon parameters)
            weight_decay: Weight decay coefficient
            momentum: Momentum factor for Muon
            nesterov: Whether to use Nesterov momentum
            ns_steps: Number of Newton-Schulz iterations
            ns_coefficients: Coefficients for NS iteration
            eps: Numerical stability epsilon
            safety_factor: Multiplicative safety factor for NS norm
            adjust_lr_fn: LR adjustment function - "original" or "match_rms_adamw"
            conv_mode: How to handle convolutions - "flatten" or "batched"
            normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode
            adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified
            betas: AdamW beta coefficients
            verbose: Log parameter routing decisions (Muon vs AdamW)

        Example:
            ```python
            # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D
            optimizer = Muon(model.parameters(), lr=0.02)

            # Manual control over parameter groups
            optimizer = Muon([
                {'params': weight_matrices, 'lr': 0.02},
                {'params': biases, 'use_fallback': True, 'lr': 3e-4},  # use AdamW if use_fallback=True
            ])
            ```
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if conv_mode not in ["flatten", "batched"]:
            raise ValueError(f"Invalid conv_mode: {conv_mode}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            ns_coefficients=ns_coefficients,
            eps=eps,
            safety_factor=safety_factor,
            adjust_lr_fn=adjust_lr_fn,
            conv_mode=conv_mode,
            normalize_spatial=normalize_spatial,
            # AdamW fallback reuses the Muon lr unless explicitly overridden
            adamw_lr=adamw_lr if adamw_lr is not None else lr,
            betas=betas,
            verbose=verbose,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters are routed per-tensor to either the Muon update or the
        AdamW/NAdamW fallback. The routing decision is made once per parameter
        (on first encounter) and cached in the optimizer state.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        verbose = self.defaults.get("verbose", False)

        # Tracking for logging (populated on first encounter of each param)
        muon_count = 0
        adamw_count = 0
        routing_reasons = {} if verbose else None

        for group in self.param_groups:
            # Separate params into Muon and AdamW groups
            muon_params = []
            muon_grads = []
            muon_momentum_bufs = []

            adamw_params = []
            adamw_grads = []
            adamw_exp_avgs = []
            adamw_exp_avg_sqs = []
            adamw_state_steps = []

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("Muon does not support sparse gradients")

                state = self.state[p]

                # Determine routing on first encounter (cache in state)
                if "use_muon" not in state:
                    # Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility)
                    reason = None
                    if group.get("use_fallback", False):
                        # use_fallback=True means use AdamW (use_muon=False)
                        state["use_muon"] = False
                        if verbose:
                            reason = "use_fallback_flag"
                    elif "use_muon" in group:
                        # Explicit use_muon flag for compatibility with other Muon implementations
                        state["use_muon"] = group["use_muon"]
                        if verbose:
                            reason = "use_muon_flag"
                    else:
                        # No explicit flag: decide automatically by shape suitability
                        if verbose:
                            suitable, reason = _is_suitable_for_muon(p, return_reason=True)
                        else:
                            suitable = _is_suitable_for_muon(p, return_reason=False)
                        state["use_muon"] = suitable

                    # Track routing decision for logging (verbose only; keyed by shape)
                    if routing_reasons is not None and reason is not None:
                        shape_str = "x".join(str(s) for s in p.shape)
                        if shape_str not in routing_reasons:
                            routing_reasons[shape_str] = []
                        routing_reasons[shape_str].append(reason)

                # Use cached routing decision
                use_muon = state["use_muon"]
                if use_muon:
                    # Collect Muon params
                    muon_params.append(p)
                    muon_grads.append(p.grad)
                    muon_count += 1

                    # State initialization for Muon
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    muon_momentum_bufs.append(state["momentum_buffer"])
                else:
                    # Collect AdamW/NAdamW params
                    adamw_params.append(p)
                    adamw_grads.append(p.grad)
                    adamw_count += 1

                    # State initialization for AdamW
                    if "step" not in state:
                        state["step"] = torch.tensor(0.)
                        state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    adamw_exp_avgs.append(state["exp_avg"])
                    adamw_exp_avg_sqs.append(state["exp_avg_sq"])
                    adamw_state_steps.append(state["step"])

            # Apply Muon updates
            if muon_params:
                muon(
                    muon_params,
                    muon_grads,
                    muon_momentum_bufs,
                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    momentum=group["momentum"],
                    nesterov=group["nesterov"],
                    ns_steps=group["ns_steps"],
                    ns_coefficients=group["ns_coefficients"],
                    eps=group["eps"],
                    safety_factor=group["safety_factor"],
                    adjust_lr_fn=group["adjust_lr_fn"],
                    conv_mode=group["conv_mode"],
                    normalize_spatial=group["normalize_spatial"],
                )

            # Apply AdamW updates
            if adamw_params:
                beta1, beta2 = group["betas"]
                if group["nesterov"]:
                    # use nadamw for fallback optimizer if nesterov is enabled
                    nadamw(
                        adamw_params,
                        adamw_grads,
                        adamw_exp_avgs,
                        adamw_exp_avg_sqs,
                        adamw_state_steps,
                        foreach=None,
                        beta1=beta1,
                        beta2=beta2,
                        lr=group["adamw_lr"],
                        weight_decay=group["weight_decay"],
                        eps=group["eps"],
                        caution=False,
                        maximize=False,
                        capturable=False,
                        max_lr=None,
                    )
                else:
                    adamw(
                        adamw_params,
                        adamw_grads,
                        adamw_exp_avgs,
                        adamw_exp_avg_sqs,
                        [],  # max_exp_avg_sqs (not using amsgrad)
                        adamw_state_steps,
                        foreach=None,
                        amsgrad=False,
                        beta1=beta1,
                        beta2=beta2,
                        lr=group["adamw_lr"],
                        weight_decay=group["weight_decay"],
                        eps=group["eps"],
                        caution=False,
                        maximize=False,
                        capturable=False,
                        max_lr=None,
                    )

        # Log routing summary when we have new routing decisions
        # (only non-empty on verbose steps that saw first-time params)
        if routing_reasons and len(routing_reasons) > 0:
            # Concise summary
            _logger.info(f"Muon parameter routing: {muon_count} Muon, {adamw_count} AdamW")

            # Group by reason for detailed breakdown
            reason_groups = {}
            for shape_str, reasons in sorted(routing_reasons.items()):
                for reason in reasons:
                    if reason not in reason_groups:
                        reason_groups[reason] = []
                    reason_groups[reason].append(shape_str)

            # Log summary counts per reason
            reason_summary = []
            for reason, shapes in sorted(reason_groups.items()):
                reason_summary.append(f"{reason}={len(shapes)}")
            _logger.info(f"  Breakdown: {', '.join(reason_summary)}")

            # Detailed breakdown at INFO level
            if _logger.isEnabledFor(logging.INFO):
                for reason, shapes in sorted(reason_groups.items()):
                    optimizer_name = "Muon" if reason == "ok" else "AdamW"
                    _logger.info(f"  {reason} -> {optimizer_name}:")
                    for shape in shapes[:10]:
                        _logger.info(f"    {shape}")
                    if len(shapes) > 10:
                        _logger.info(f"    ... and {len(shapes) - 10} more")

        return loss
|
|
|
|
|
|
def resolve_ns_coefficients(
    value: Union[str, Sequence[float], Sequence[Sequence[float]]],
    presets: Mapping[str, Sequence[Sequence[float]]]
) -> List[Tuple[float, float, float]]:
    """Normalize a coefficient spec into a validated list of (a, b, c) float triples.

    Accepts a preset name (looked up in ``presets``), a single 3-sequence,
    or a sequence of 3-sequences.
    """

    def _is_sequence(obj) -> bool:
        # Sequences, but not strings/bytes (which are also Sequence).
        return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))

    def _is_number(obj) -> bool:
        # Real numbers, excluding bool (a numbers.Real subclass).
        return isinstance(obj, numbers.Real) and not isinstance(obj, bool)

    def _to_triple(item) -> Tuple[float, float, float]:
        # Validate and cast one (a, b, c) entry.
        if not _is_sequence(item) or len(item) != 3 or not all(_is_number(v) for v in item):
            raise ValueError(f"Coefficient must be length-3 of real numbers, got: {item!r}")
        a, b, c = item
        return float(a), float(b), float(c)

    if isinstance(value, str):
        if value not in presets:
            valid = ", ".join(sorted(presets.keys()))
            raise ValueError(f"Unknown coefficients preset '{value}'. Valid options: {valid}")
        preset = presets[value]
        if not _is_sequence(preset) or len(preset) == 0:
            raise ValueError(f"Preset '{value}' is empty or invalid")
        return [_to_triple(entry) for entry in preset]  # validate & cast

    if not _is_sequence(value):
        raise TypeError(
            "Coefficients must be a preset name (str), a 3-sequence (a,b,c), "
            "or a sequence of 3-sequences."
        )

    # A bare triple of reals is wrapped; anything else is treated as a list of triples.
    if len(value) == 3 and all(_is_number(v) for v in value):
        return [_to_triple(value)]

    triples: List[Tuple[float, float, float]] = []
    for i, item in enumerate(value):
        if not _is_sequence(item):
            raise TypeError(f"Item {i} is not a sequence: {item!r}")
        triples.append(_to_triple(item))
    if not triples:
        raise ValueError("Coefficient list cannot be empty")
    return triples