set allow_implicit=False by default (#15319)

* set allow_implicit=False by default

* modernize beautiful mnist
This commit is contained in:
George Hotz 2026-03-17 17:14:38 +08:00 committed by GitHub
commit 9d95321be3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 33 additions and 32 deletions

View file

@ -1,6 +1,6 @@
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
@ -15,30 +15,31 @@ class Model:
nn.BatchNorm(64), Tensor.max_pool2d,
lambda x: x.flatten(1), nn.Linear(576, 10)]
@function
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit
@Tensor.train()
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
loss = self(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
return loss.realize(*opt.schedule_step())
@TinyJit
def get_test_acc(self, X_test:Tensor, Y_test:Tensor) -> Tensor: return (self(X_test).argmax(axis=1) == Y_test).mean()*100
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
model = Model()
opt = (nn.optim.Muon if getenv("MUON") else nn.optim.SGD if getenv("SGD") else nn.optim.Adam)(nn.state.get_parameters(model))
@TinyJit
@Tensor.train()
def train_step() -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
return loss.realize(*opt.schedule_step())
@TinyJit
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
test_acc = float('nan')
for i in (t:=trange(getenv("STEPS", 70))):
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
loss = train_step()
if i%10 == 9: test_acc = get_test_acc().item()
loss = model.train_step(X_train, Y_train)
if i%10 == 9: test_acc = model.get_test_acc(X_test, Y_test).item()
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
# verify eval acc

View file

@ -22,7 +22,7 @@ class TestFunction(unittest.TestCase):
def test_implicit(self):
inp = Tensor([7,8,9])
@function
@function(allow_implicit=True)
def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
a = Tensor([1,2,3])
@ -31,7 +31,7 @@ class TestFunction(unittest.TestCase):
def test_implicit_same_as_input(self):
inp = Tensor([7,8,9])
@function
@function(allow_implicit=True)
def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
a = Tensor([1,2,3])
@ -39,11 +39,11 @@ class TestFunction(unittest.TestCase):
def test_implicit_2(self):
inp = Tensor([7,8,9])
@function
@function(allow_implicit=True)
def f(a:Tensor, b:Tensor) -> Tensor:
return a+b+inp
inp2 = Tensor([7,8,10])
@function
@function(allow_implicit=True)
def g(a:Tensor, b:Tensor) -> Tensor:
return a+b+inp2
@ -57,7 +57,7 @@ class TestFunction(unittest.TestCase):
def test_implicit_unrealized(self):
inp = Tensor([1,2,3]) + Tensor([4,5,6])
@function
@function(allow_implicit=True)
def f(a:Tensor) -> Tensor: return a + inp
np.testing.assert_equal(f(Tensor([10,20,30])).numpy(), [15,27,39])
@ -103,7 +103,7 @@ class TestFunction(unittest.TestCase):
def test_grad_implicit(self):
w = Tensor([1., 2., 3.], requires_grad=True)
w.realize() # TODO: this is required
@function
@function(allow_implicit=True)
def f(x:Tensor) -> Tensor: return x * w
x = Tensor([4., 5., 6.])
@ -112,7 +112,7 @@ class TestFunction(unittest.TestCase):
def test_symbolic_index(self):
table = Tensor([10,20,30,40]).contiguous().realize()
@function
@function(allow_implicit=True)
def f(x:Tensor, start_pos:int|UOp) -> Tensor:
return x + table[start_pos]
@ -129,9 +129,9 @@ class TestFunction(unittest.TestCase):
def test_nested_calls(self):
w = Tensor([10., 20., 30.])
@function
@function(allow_implicit=True)
def f(a:Tensor) -> Tensor: return a + w
@function
@function(allow_implicit=True)
def g(a:Tensor) -> Tensor: return a * w
a = Tensor([1., 2., 3.])
@ -139,9 +139,9 @@ class TestFunction(unittest.TestCase):
def test_nested_calls_backward(self):
w = Tensor([[1., 2.], [3., 4.]]).contiguous().realize()
@function
@function(allow_implicit=True)
def inner(x:Tensor) -> Tensor: return x + w
@function
@function(allow_implicit=True)
def outer(a:Tensor, b:Tensor) -> Tensor: return inner(a.reshape(1,2) + b.reshape(1,2))
a = Tensor([1., 2.], requires_grad=True)
@ -178,7 +178,7 @@ class TestFunction(unittest.TestCase):
def __init__(self): self.w = Tensor([10,20,30])
def __call__(self, x:Tensor) -> Tensor: return x + self.w
foo = Foo()
f = function(foo)
f = function(foo, allow_implicit=True)
np.testing.assert_equal(f(Tensor([1,2,3])).numpy(), [11,22,33])
assert f(Tensor([1,2,3])).uop.src[0].arg.name.endswith("Foo")
@ -267,7 +267,7 @@ class TestFunctionMulti(unittest.TestCase):
def test_grad_implicit_multi(self):
w = Tensor([1., 2., 3., 4.], requires_grad=True).shard(self.devices_2, axis=None)
w.realize()
@function
@function(allow_implicit=True)
def f(x:Tensor) -> Tensor: return x * w
x = Tensor([4., 5., 6., 7.]).shard(self.devices_2, axis=None)
@ -324,7 +324,7 @@ class TestFunctionMulti(unittest.TestCase):
devices_4 = tuple(f"CPU:{i}" for i in range(4))
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
w.realize()
@function
@function(allow_implicit=True)
def f(x:Tensor) -> Tensor: return x @ w
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
@ -337,7 +337,7 @@ class TestFunctionMulti(unittest.TestCase):
w.realize()
# pre-init grads like the training loop does
w.grad = w.zeros_like().contiguous().realize()
@function
@function(allow_implicit=True)
def f(x:Tensor) -> Tensor: return x @ w
expected = np.ones((8,2)) @ np.array([[1,3],[2,4]])

View file

@ -20,8 +20,7 @@ pm_ctx = PatternMatcher([
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=True, grad_fxn:Callable|None=None):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool, precompile_backward:bool, allow_implicit:bool, grad_fxn:Callable|None):
self.fxn = fxn
self.precompile = precompile
self.precompile_backward = precompile_backward
@ -82,11 +81,12 @@ class _function(Generic[ReturnType]):
# overload signatures support both @function and @function(precompile=True) syntax
@overload
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=True, grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
allow_implicit:bool=False, grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
@overload
def function(fxn:None=None, *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=True, grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
allow_implicit:bool=False, grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=False, grad_fxn:Callable|None=None):
if fxn is None:
return lambda f: _function(f, precompile=precompile, precompile_backward=precompile_backward,
allow_implicit=allow_implicit, grad_fxn=grad_fxn)