add STEPS to beautiful_mnist

This commit is contained in:
George Hotz 2024-08-10 15:23:44 -07:00
commit 14b613e281

View file

@ -37,7 +37,7 @@ if __name__ == "__main__":
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
test_acc = float('nan')
for i in (t:=trange(70)):
for i in (t:=trange(getenv("STEPS", 70))):
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
loss = train_step()
if i%10 == 9: test_acc = get_test_acc().item()