mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
[bounty] Muon optim (#11414)
* newton schulz * add muon + move newton schulz to tensor * compact newton schulz * better tests * cleanup * add comments for muon * cleanup * add export with tests * match muon optim with test optim * cleanup * unsed import * correct comment * whitespace * move export * muon test fix * match reference impl + tests * remove export by moving muon device * add credit * cleanup * remove print * spacing * spacing * comma * cleanup * removal * fix tests + optim momentum * consistent is not/ not * more consistency * fix test * cleanup * fix the nones * remove comment * cast * comment * comment * muon teeny test * muon flag beautiful mnist * set steps * steps as hyperparam * match default test steps * name * large cleanup * dont care about steps * nesterov false default * match each other impl * steps * switch nest * swap defaults * update docstring * add no nesterov test * ban fuse_optim * prints * classical momentum * alternative condition * recon * pre + post wd * false default * detach * signature changes * context * swap order * big cleanup * 0 step instead * parity * remove fuse * remove fused * better paper * assert message * correct shape check + eps * multidim * add eps * cleanup * correct assert message * lint * better tests * naming * ns_steps,ns_params * update docstring * docstring * match sgd and muon together * sandwich * add back fused * parity --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
94e6d84e32
commit
e2873a3a41
6 changed files with 149 additions and 6 deletions
|
|
@ -21,7 +21,7 @@ if __name__ == "__main__":
|
|||
X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
|
||||
|
||||
model = Model()
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
||||
opt = (nn.optim.Adam if not getenv("MUON") else nn.optim.Muon)(nn.state.get_parameters(model))
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
|
|
|
|||
75
extra/torch_muon.py
Normal file
75
extra/torch_muon.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
import torch
|
||||
|
||||
#credit to KellerJordan at https://github.com/KellerJordan/Muon/tree/master
|
||||
#some changes: classic momentum instead of weighting gradient
|
||||
#added ns_steps, ns_params, nesterov as hyperparams
|
||||
def zeropower_via_newtonschulz5(G:torch.tensor, steps:int, params:tuple[int, ...]):
|
||||
"""
|
||||
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
||||
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
||||
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
|
||||
|
||||
a, b, c = params
|
||||
X = G
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
|
||||
# Ensure spectral norm is at most 1
|
||||
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
A = X @ X.mT
|
||||
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
|
||||
X = a * X + B @ X
|
||||
|
||||
if G.size(-2) > G.size(-1):
|
||||
X = X.mT
|
||||
|
||||
return X
|
||||
|
||||
def muon_update(grad, momentum, beta=0.95, ns_steps=5, ns_params=(3.4445, -4.7750, 2.0315), nesterov=True):
|
||||
if beta:
|
||||
momentum.mul_(beta).add_(grad)
|
||||
update = grad.add(momentum,alpha=beta) if nesterov else momentum
|
||||
else: update = grad
|
||||
if update.ndim == 4: # for the case of conv filters
|
||||
update = update.view(len(update), -1)
|
||||
update = zeropower_via_newtonschulz5(update, steps=ns_steps, params=ns_params)
|
||||
return update
|
||||
|
||||
class SingleDeviceMuon(torch.optim.Optimizer):
|
||||
"""
|
||||
Muon variant for usage in non-distributed settings.
|
||||
"""
|
||||
def __init__(self, params, lr=0.02, weight_decay=0.0, momentum=0.95, ns_steps=5, ns_params=(3.4445, -4.7750, 2.0315), nesterov=True):
|
||||
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_steps=ns_steps, ns_params=ns_params, nesterov=nesterov)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
p.grad = torch.zeros_like(p) # Force synchronization
|
||||
state = self.state[p]
|
||||
if len(state) == 0:
|
||||
state["momentum_buffer"] = torch.zeros_like(p)
|
||||
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"], ns_steps=group["ns_steps"],
|
||||
ns_params=group["ns_params"], nesterov=group["nesterov"])
|
||||
p.mul_(1.0 - group["lr"] * group["weight_decay"])
|
||||
|
||||
p.add_(update.reshape(p.shape), alpha=-group["lr"])
|
||||
|
||||
return loss
|
||||
|
|
@ -62,5 +62,15 @@ class TestLinAlg(unittest.TestCase):
|
|||
orthogonality_helper(Q)
|
||||
reconstruction_helper([Q,R],a)
|
||||
|
||||
def test_newton_schulz(self):
|
||||
coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function
|
||||
sizes = [(2,2), (3,2), (2,3), (2,2,2)]
|
||||
for coefs in coefficients:
|
||||
for size in sizes:
|
||||
a = Tensor.randn(size)
|
||||
b = Tensor.newton_schulz(a, steps=20, params=coefs, eps=0.0)
|
||||
# ns(A) = U @ Vt -> (U @ Vt) @ (U @ Vt)t = I
|
||||
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -2,9 +2,10 @@ import numpy as np
|
|||
import torch
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.nn.optim import Adam, SGD, AdamW
|
||||
from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from extra.torch_muon import SingleDeviceMuon as TorchMuon
|
||||
|
||||
np.random.seed(1337)
|
||||
x_init = np.random.randn(1,4).astype(np.float32)
|
||||
|
|
@ -57,9 +58,12 @@ class TestOptim(unittest.TestCase):
|
|||
def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
|
||||
def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
|
||||
def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
|
||||
#TODO: use torch.muon when it comes out
|
||||
def _test_muon(self, steps, opts, atol, rtol): self._test_optim(Muon, TorchMuon, steps, opts, atol, rtol)
|
||||
|
||||
def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
|
||||
def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
|
||||
def test_multistep_muon_high_lr_teeny(self): self._test_muon(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
|
||||
|
||||
def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
|
||||
def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
|
||||
|
|
@ -83,6 +87,28 @@ class TestOptim(unittest.TestCase):
|
|||
def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
|
||||
self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)
|
||||
|
||||
def test_muon(self): self._test_muon(1, {'lr': 0.001}, 1e-6, 0)
|
||||
def test_muon_high_lr(self): self._test_muon(1, {'lr': 10}, 1e-6, 3e-4)
|
||||
def test_muon_wd(self): self._test_muon(1, {'lr': 0.001, 'weight_decay': 0.01}, 1e-6, 0)
|
||||
def test_muon_high_lr_wd(self): self._test_muon(1, {'lr': 10, 'weight_decay': 0.01}, 1e-6, 3e-4)
|
||||
|
||||
# NOTE: momentum set to 0.95 by default, nesterov set to True by default
|
||||
def test_multistep_muon_momentum_wd(self): self._test_muon(10, {'lr': 0.001, 'weight_decay': 0.01}, 1e-5, 0)
|
||||
# ns defaults are numerically unstable, but it is tolerable in real training (see nsteps/nparam tests)
|
||||
def test_multistep_muon_high_lr_momentum_wd(self): self._test_muon(10, {'lr': 10, 'weight_decay': 0.01}, 1e-1, 3e-4)
|
||||
def test_multistep_muon_no_nesterov_momentum(self): self._test_muon(10, {'lr': 0.001, 'nesterov': False}, 1e-5, 0)
|
||||
def test_multistep_muon_high_lr_no_nesterov_momentum(self): self._test_muon(10, {'lr': 10, 'nesterov': False}, 0.5e-1, 1e-1)
|
||||
|
||||
def test_muon_ns_steps(self): self._test_muon(1, {'lr': 0.001, 'ns_steps': 3}, 1e-6, 0)
|
||||
def test_muon_high_lr_ns_steps(self): self._test_muon(1, {'lr': 10, 'ns_steps': 3}, 1e-5, 3e-4)
|
||||
def test_muon_ns_params(self): self._test_muon(1, {'lr': 0.001,'ns_params': (2.0,-1.5,0.5)}, 1e-6, 0)
|
||||
def test_muon_high_lr_ns_params(self): self._test_muon(1, {'lr': 10,'ns_params': (2.0,-1.5,0.5)}, 1e-5, 3e-4)
|
||||
|
||||
def test_muon_momentum_wd_ns_steps_ns_params(self):
|
||||
self._test_muon(10, {'lr': 0.001, 'momentum': 0.90, 'weight_decay': 0.01, 'ns_steps': 3, 'ns_params': (2.0,-1.5,0.5)}, 1e-5, 0)
|
||||
def test_multistep_muon_high_lr_momentum_wd_ns_steps_ns_params(self):
|
||||
self._test_muon(10, {'lr': 10, 'momentum': 0.90, 'weight_decay': 0.01, 'ns_steps': 3, 'ns_params': (2.0,-1.5,0.5)}, 1e-5, 3e-4)
|
||||
|
||||
def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
|
||||
def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4)
|
||||
def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,19 @@ def SGD(params: list[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov
|
|||
|
||||
`classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.
|
||||
"""
|
||||
return LARS(params, lr, momentum, weight_decay, nesterov, classic, tcoef=0.0, fused=fused)
|
||||
return LARS(params, lr, momentum, weight_decay, 0, None, nesterov, classic=classic, pre_wd=True, tcoef=0.0, fused=fused)
|
||||
|
||||
# Muon applies the newton schulz algorithm on gradient. also can include momentum, nesterov, and weight decay
|
||||
def Muon(params: list[Tensor], lr=0.02, momentum=0.95, weight_decay=0.0, ns_steps=5, ns_params=(3.4445, -4.775, 2.0315),
|
||||
nesterov=True, fused=FUSE_OPTIM):
|
||||
"""
|
||||
SGD with newton-schulz iteration and post momentum weight decay.
|
||||
|
||||
- Described: https://kellerjordan.github.io/posts/muon/
|
||||
- Paper: https://arxiv.org/pdf/2502.16982
|
||||
"""
|
||||
assert not fused, "FUSE_OPTIM not allowed for Muon optimizer"
|
||||
return LARS(params, lr, momentum, weight_decay, ns_steps, ns_params, nesterov, classic=False, pre_wd=False, tcoef=0.0, fused=fused)
|
||||
|
||||
class LARS(Optimizer):
|
||||
"""
|
||||
|
|
@ -85,9 +97,11 @@ class LARS(Optimizer):
|
|||
|
||||
- Paper: https://arxiv.org/abs/1708.03888v3
|
||||
"""
|
||||
def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001, fused=FUSE_OPTIM):
|
||||
def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, ns_steps=0, ns_params=None,
|
||||
nesterov=False, classic=True, pre_wd=True, tcoef=0.001, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, fused)
|
||||
self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
|
||||
self.momentum, self.wd, self.ns_steps, self.ns_params = momentum, weight_decay, ns_steps, ns_params
|
||||
self.nesterov, self.classic, self.pre_wd, self.tcoef = nesterov, classic, pre_wd, tcoef
|
||||
self.b = self._new_optim_param() if self.momentum else []
|
||||
|
||||
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
|
|
@ -98,7 +112,7 @@ class LARS(Optimizer):
|
|||
r2 = g.square().sum().sqrt()
|
||||
r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
|
||||
else: r = 1.0
|
||||
if self.wd > 0: g = g + self.wd * t.detach()
|
||||
if self.pre_wd and self.wd > 0: g = g + self.wd * t.detach()
|
||||
# classic momentum does post learning rate update
|
||||
if self.classic: g = g * r * self.lr
|
||||
if self.momentum:
|
||||
|
|
@ -106,6 +120,9 @@ class LARS(Optimizer):
|
|||
# the scheduler should detect this and just insert contiguous
|
||||
self.b[i].assign(self.momentum * self.b[i].contiguous() + g) # NOTE: self.b[i] is zero on the first run, no if required
|
||||
g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
|
||||
if self.ns_params: g = g.reshape(g.shape[0], -1).newton_schulz(self.ns_steps, self.ns_params).reshape(g.shape)
|
||||
# muon does post momentum weight decay
|
||||
if not self.pre_wd and self.wd > 0: t = t.detach() * (1.0 - self.wd * self.lr)
|
||||
# popular momentum does pre learning rate update
|
||||
if not self.classic: g = g * r * self.lr
|
||||
ret.append((t.detach() - g).cast(t.dtype))
|
||||
|
|
|
|||
|
|
@ -4033,6 +4033,21 @@ class Tensor(MathTrait):
|
|||
nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
|
||||
return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
|
||||
|
||||
def newton_schulz(self, steps:int, params:tuple[int, ...], eps:float=1.0e-7) -> Tensor:
|
||||
"""
|
||||
Performs the newton-schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.randn(4, 4)
|
||||
print(t.newton_schulz(steps=5, params=(2,-1.5,0.5)).numpy())
|
||||
```
|
||||
"""
|
||||
assert self.ndim > 1, "NS only works for two or more dims"
|
||||
G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
|
||||
G = G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
|
||||
for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))
|
||||
return G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
|
||||
|
||||
def qr(self) -> tuple[Tensor, Tensor]:
|
||||
assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
|
||||
R = self.clone()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue