tinygrad/examples/beautiful_mnist.py
kevvz e2873a3a41
[bounty] Muon optim (#11414)
* newton schulz

* add muon + move newton schulz to tensor

* compact newton schulz

* better tests

* cleanup

* add comments for muon

* cleanup

* add export with tests

* match muon optim with test optim

* cleanup

* unused import

* correct comment

* whitespace

* move export

* muon test fix

* match reference impl + tests

* remove export by moving muon device

* add credit

* cleanup

* remove print

* spacing

* spacing

* comma

* cleanup

* removal

* fix tests + optim momentum

* consistent is not/ not

* more consistency

* fix test

* cleanup

* fix the nones

* remove comment

* cast

* comment

* comment

* muon teeny test

* muon flag beautiful mnist

* set steps

* steps as hyperparam

* match default test steps

* name

* large cleanup

* don't care about steps

* nesterov false default

* match each other impl

* steps

* switch nest

* swap defaults

* update docstring

* add no nesterov test

* ban fuse_optim

* prints

* classical momentum

* alternative condition

* recon

* pre + post wd

* false default

* detach

* signature changes

* context

* swap order

* big cleanup

* 0 step instead

* parity

* remove fuse

* remove fused

* better paper

* assert message

* correct shape check + eps

* multidim

* add eps

* cleanup

* correct assert message

* lint

* better tests

* naming

* ns_steps,ns_params

* update docstring

* docstring

* match sgd and muon together

* sandwich

* add back fused

* parity

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2025-08-13 14:27:55 -04:00

48 lines
1.9 KiB
Python

# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
class Model:
    """Small convolutional classifier expressed as a flat list of callables.

    The network is two conv stages (two 5x5 convs, then two 3x3 convs), each
    followed by batch norm and max pooling, and a final flatten + linear
    layer producing 10 outputs.
    """

    def __init__(self):
        # Stages are built in forward order so layers are created (and later
        # enumerated) in the same sequence in which they are applied.
        first_stage = [
            nn.Conv2d(1, 32, 5), Tensor.relu,
            nn.Conv2d(32, 32, 5), Tensor.relu,
            nn.BatchNorm(32), Tensor.max_pool2d,
        ]
        second_stage = [
            nn.Conv2d(32, 64, 3), Tensor.relu,
            nn.Conv2d(64, 64, 3), Tensor.relu,
            nn.BatchNorm(64), Tensor.max_pool2d,
        ]
        # Flatten everything after the first axis, then classify into 10 outputs.
        classifier = [lambda x: x.flatten(1), nn.Linear(576, 10)]
        self.layers: list[Callable[[Tensor], Tensor]] = [*first_stage, *second_stage, *classifier]

    def __call__(self, x: Tensor) -> Tensor:
        """Feed `x` through every entry of `self.layers` via `Tensor.sequential`."""
        return x.sequential(self.layers)
if __name__ == "__main__":
    # Load the train/test splits; the FASHION env var is forwarded as the
    # loader's `fashion` argument.
    X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))

    model = Model()
    # MUON set -> Muon optimizer, otherwise Adam; no hyperparameters are passed.
    optim_cls = nn.optim.Muon if getenv("MUON") else nn.optim.Adam
    opt = optim_cls(nn.state.get_parameters(model))

    @TinyJit
    @Tensor.train()
    def train_step() -> Tensor:
        """Run one optimizer step on a randomly indexed batch and return its loss."""
        opt.zero_grad()
        batch_size = getenv("BS", 512)
        # Random row indices into the training set, one per batch element.
        batch_idx = Tensor.randint(batch_size, high=X_train.shape[0])
        logits = model(X_train[batch_idx])
        batch_loss = logits.sparse_categorical_crossentropy(Y_train[batch_idx]).backward()
        opt.step()
        return batch_loss

    @TinyJit
    def get_test_acc() -> Tensor:
        """Return the accuracy on the full test split, as a percentage."""
        predictions = model(X_test).argmax(axis=1)
        return (predictions == Y_test).mean() * 100

    # Stays NaN until the first evaluation has run.
    test_acc = float('nan')
    progress = trange(getenv("STEPS", 70))
    for step in progress:
        GlobalCounters.reset()   # NOTE: this makes it nice for DEBUG=2 timing
        step_loss = train_step()
        # Evaluate on every 10th step (steps 9, 19, 29, ... counting from 0).
        if (step + 1) % 10 == 0:
            test_acc = get_test_acc().item()
        progress.set_description(f"loss: {step_loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

    # verify eval acc: only enforced when TARGET_EVAL_ACC_PCT is set to a truthy value
    target = getenv("TARGET_EVAL_ACC_PCT", 0.0)
    if target:
        # An accuracy of exactly 100.0 is rejected along with anything below target.
        reached_target = test_acc >= target and test_acc != 100.0
        if not reached_target:
            raise ValueError(colored(f"{test_acc=} < {target}", "red"))
        print(colored(f"{test_acc=} >= {target}", "green"))