noether.core.optimizer.muon_composite

MuonComposite: torch.optim.Muon for 2D params, any optimizer for the rest.

Parameters are routed by dimensionality: those with ndim >= 2 go to Muon, and everything else (biases, norms, 1D embeddings) goes to the secondary optimizer.
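The routing rule above can be sketched in a few lines. This is a minimal illustration of the ndim >= 2 split, not the library's implementation: `split_by_ndim` and `FakeParam` are hypothetical names introduced here, and `FakeParam` is only a stand-in for `torch.nn.Parameter` so that the sketch runs without PyTorch.

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter; only the shape matters."""
    shape: tuple

    @property
    def ndim(self) -> int:
        return len(self.shape)


def split_by_ndim(named_params):
    """Split (name, param) pairs into Muon and secondary buckets by ndim."""
    muon, secondary = [], []
    for name, param in named_params:
        # Rule stated in the docs: ndim >= 2 goes to Muon, the rest does not.
        (muon if param.ndim >= 2 else secondary).append(name)
    return muon, secondary


named = [
    ("linear.weight", FakeParam((64, 32))),  # 2D weight matrix
    ("linear.bias", FakeParam((64,))),       # 1D bias
    ("norm.weight", FakeParam((64,))),       # 1D norm scale
]
muon_names, secondary_names = split_by_ndim(named)
print(muon_names)       # ['linear.weight']
print(secondary_names)  # ['linear.bias', 'norm.weight']
```

With a real model, the same function would presumably be fed `model.named_parameters()`, since tensors expose `ndim` as well.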

Classes

MuonComposite

Composite optimizer using torch.optim.Muon for 2D weight matrices and a configurable secondary optimizer for all other parameters.

Module Contents

class noether.core.optimizer.muon_composite.MuonComposite(params, lr=0.01, momentum=0.95, weight_decay=0.01, secondary=None, nesterov=None, ns_steps=None, adjust_lr_fn=None)

Bases: torch.optim.Optimizer

Composite optimizer using torch.optim.Muon for 2D weight matrices and a configurable secondary optimizer for all other parameters (biases, norms, embeddings).

Parameters:
  • params – Iterable of parameter groups.

  • lr – Learning rate for the Muon optimizer.

  • momentum – Momentum factor for the Muon optimizer.

  • weight_decay – Weight decay for the Muon optimizer.

  • secondary – Configuration dict for the secondary optimizer (biases, norms, embeddings).

  • nesterov – Enable Nesterov momentum in Muon. None uses Muon’s default (True).

  • ns_steps – Number of Newton-Schulz iteration steps. None uses Muon’s default (5).

  • adjust_lr_fn – Per-matrix LR adjustment strategy for Muon. One of "original" or "match_rms_adamw". None uses Muon’s default ("original").

param_groups = []
step(closure=None)
zero_grad(set_to_none=True)
state_dict()
load_state_dict(state_dict)
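The methods listed above suggest a wrapper that forwards each call to both underlying optimizers. The sketch below shows that delegation pattern under stated assumptions: `CompositeSketch` and `CountingOptimizer` are hypothetical names, the nested state layout keyed by "muon" and "secondary" is an assumption rather than the documented format, and the stub optimizer replaces real torch optimizers so the example runs on its own.

```python
class CountingOptimizer:
    """Hypothetical stub exposing the optimizer methods used below."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]
        self.steps = 0

    def step(self):
        self.steps += 1

    def zero_grad(self, set_to_none=True):
        pass  # a real optimizer would clear gradients here

    def state_dict(self):
        return {"steps": self.steps}

    def load_state_dict(self, state):
        self.steps = state["steps"]


class CompositeSketch:
    """Hypothetical composite: forwards each call to every wrapped optimizer."""

    def __init__(self, optimizers):
        self.optimizers = dict(optimizers)

    @property
    def param_groups(self):
        # Expose the groups of all wrapped optimizers as one flat list.
        return [g for opt in self.optimizers.values() for g in opt.param_groups]

    def step(self, closure=None):
        # Evaluate the closure once, then step each wrapped optimizer.
        loss = closure() if closure is not None else None
        for opt in self.optimizers.values():
            opt.step()
        return loss

    def zero_grad(self, set_to_none=True):
        for opt in self.optimizers.values():
            opt.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        # Assumed layout: one nested entry per wrapped optimizer.
        return {name: opt.state_dict() for name, opt in self.optimizers.items()}

    def load_state_dict(self, state_dict):
        for name, opt in self.optimizers.items():
            opt.load_state_dict(state_dict[name])


opt = CompositeSketch({"muon": CountingOptimizer(0.01),
                       "secondary": CountingOptimizer(0.001)})
opt.zero_grad()
opt.step()
saved = opt.state_dict()

restored = CompositeSketch({"muon": CountingOptimizer(0.01),
                            "secondary": CountingOptimizer(0.001)})
restored.load_state_dict(saved)
print(saved)  # {'muon': {'steps': 1}, 'secondary': {'steps': 1}}
```

Keeping a separate state entry per wrapped optimizer is one way to let a checkpoint round-trip without mixing Muon state with the secondary optimizer's state; whether MuonComposite stores its state this way is not stated here.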