mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
set allow_implicit=False by default (#15319)
* set allow_implicit=False by default * modernize beautiful mnist
This commit is contained in:
parent
e1c2d09720
commit
9d95321be3
3 changed files with 33 additions and 32 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
|
||||
from typing import Callable
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
||||
from tinygrad import Tensor, TinyJit, nn, GlobalCounters, function
|
||||
from tinygrad.helpers import getenv, colored, trange
|
||||
from tinygrad.nn.datasets import mnist
|
||||
|
||||
|
|
@ -15,30 +15,31 @@ class Model:
|
|||
nn.BatchNorm(64), Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
@function
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step(self, X_train:Tensor, Y_train:Tensor) -> Tensor:
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
loss = self(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
return loss.realize(*opt.schedule_step())
|
||||
|
||||
@TinyJit
|
||||
def get_test_acc(self, X_test:Tensor, Y_test:Tensor) -> Tensor: return (self(X_test).argmax(axis=1) == Y_test).mean()*100
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
|
||||
|
||||
model = Model()
|
||||
opt = (nn.optim.Muon if getenv("MUON") else nn.optim.SGD if getenv("SGD") else nn.optim.Adam)(nn.state.get_parameters(model))
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train()
|
||||
def train_step() -> Tensor:
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
return loss.realize(*opt.schedule_step())
|
||||
|
||||
@TinyJit
|
||||
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
|
||||
|
||||
test_acc = float('nan')
|
||||
for i in (t:=trange(getenv("STEPS", 70))):
|
||||
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
||||
loss = train_step()
|
||||
if i%10 == 9: test_acc = get_test_acc().item()
|
||||
loss = model.train_step(X_train, Y_train)
|
||||
if i%10 == 9: test_acc = model.get_test_acc(X_test, Y_test).item()
|
||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||
|
||||
# verify eval acc
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class TestFunction(unittest.TestCase):
|
|||
|
||||
def test_implicit(self):
|
||||
inp = Tensor([7,8,9])
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
|
||||
|
||||
a = Tensor([1,2,3])
|
||||
|
|
@ -31,7 +31,7 @@ class TestFunction(unittest.TestCase):
|
|||
|
||||
def test_implicit_same_as_input(self):
|
||||
inp = Tensor([7,8,9])
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
|
||||
|
||||
a = Tensor([1,2,3])
|
||||
|
|
@ -39,11 +39,11 @@ class TestFunction(unittest.TestCase):
|
|||
|
||||
def test_implicit_2(self):
|
||||
inp = Tensor([7,8,9])
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(a:Tensor, b:Tensor) -> Tensor:
|
||||
return a+b+inp
|
||||
inp2 = Tensor([7,8,10])
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def g(a:Tensor, b:Tensor) -> Tensor:
|
||||
return a+b+inp2
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ class TestFunction(unittest.TestCase):
|
|||
|
||||
def test_implicit_unrealized(self):
|
||||
inp = Tensor([1,2,3]) + Tensor([4,5,6])
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(a:Tensor) -> Tensor: return a + inp
|
||||
|
||||
np.testing.assert_equal(f(Tensor([10,20,30])).numpy(), [15,27,39])
|
||||
|
|
@ -103,7 +103,7 @@ class TestFunction(unittest.TestCase):
|
|||
def test_grad_implicit(self):
|
||||
w = Tensor([1., 2., 3.], requires_grad=True)
|
||||
w.realize() # TODO: this is required
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(x:Tensor) -> Tensor: return x * w
|
||||
|
||||
x = Tensor([4., 5., 6.])
|
||||
|
|
@ -112,7 +112,7 @@ class TestFunction(unittest.TestCase):
|
|||
|
||||
def test_symbolic_index(self):
|
||||
table = Tensor([10,20,30,40]).contiguous().realize()
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(x:Tensor, start_pos:int|UOp) -> Tensor:
|
||||
return x + table[start_pos]
|
||||
|
||||
|
|
@ -129,9 +129,9 @@ class TestFunction(unittest.TestCase):
|
|||
|
||||
def test_nested_calls(self):
|
||||
w = Tensor([10., 20., 30.])
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(a:Tensor) -> Tensor: return a + w
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def g(a:Tensor) -> Tensor: return a * w
|
||||
|
||||
a = Tensor([1., 2., 3.])
|
||||
|
|
@ -139,9 +139,9 @@ class TestFunction(unittest.TestCase):
|
|||
|
||||
def test_nested_calls_backward(self):
|
||||
w = Tensor([[1., 2.], [3., 4.]]).contiguous().realize()
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def inner(x:Tensor) -> Tensor: return x + w
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def outer(a:Tensor, b:Tensor) -> Tensor: return inner(a.reshape(1,2) + b.reshape(1,2))
|
||||
|
||||
a = Tensor([1., 2.], requires_grad=True)
|
||||
|
|
@ -178,7 +178,7 @@ class TestFunction(unittest.TestCase):
|
|||
def __init__(self): self.w = Tensor([10,20,30])
|
||||
def __call__(self, x:Tensor) -> Tensor: return x + self.w
|
||||
foo = Foo()
|
||||
f = function(foo)
|
||||
f = function(foo, allow_implicit=True)
|
||||
np.testing.assert_equal(f(Tensor([1,2,3])).numpy(), [11,22,33])
|
||||
assert f(Tensor([1,2,3])).uop.src[0].arg.name.endswith("Foo")
|
||||
|
||||
|
|
@ -267,7 +267,7 @@ class TestFunctionMulti(unittest.TestCase):
|
|||
def test_grad_implicit_multi(self):
|
||||
w = Tensor([1., 2., 3., 4.], requires_grad=True).shard(self.devices_2, axis=None)
|
||||
w.realize()
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(x:Tensor) -> Tensor: return x * w
|
||||
|
||||
x = Tensor([4., 5., 6., 7.]).shard(self.devices_2, axis=None)
|
||||
|
|
@ -324,7 +324,7 @@ class TestFunctionMulti(unittest.TestCase):
|
|||
devices_4 = tuple(f"CPU:{i}" for i in range(4))
|
||||
w = Tensor([[1.,2.],[3.,4.]], requires_grad=True).shard(devices_4, axis=None)
|
||||
w.realize()
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(x:Tensor) -> Tensor: return x @ w
|
||||
|
||||
x = Tensor(np.arange(16).reshape(8,2).astype(np.float32), requires_grad=True).shard(devices_4, axis=0)
|
||||
|
|
@ -337,7 +337,7 @@ class TestFunctionMulti(unittest.TestCase):
|
|||
w.realize()
|
||||
# pre-init grads like the training loop does
|
||||
w.grad = w.zeros_like().contiguous().realize()
|
||||
@function
|
||||
@function(allow_implicit=True)
|
||||
def f(x:Tensor) -> Tensor: return x @ w
|
||||
|
||||
expected = np.ones((8,2)) @ np.array([[1,3],[2,4]])
|
||||
|
|
|
|||
|
|
@ -20,8 +20,7 @@ pm_ctx = PatternMatcher([
|
|||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
class _function(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
|
||||
allow_implicit:bool=True, grad_fxn:Callable|None=None):
|
||||
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool, precompile_backward:bool, allow_implicit:bool, grad_fxn:Callable|None):
|
||||
self.fxn = fxn
|
||||
self.precompile = precompile
|
||||
self.precompile_backward = precompile_backward
|
||||
|
|
@ -82,11 +81,12 @@ class _function(Generic[ReturnType]):
|
|||
# overload signatures support both @function and @function(precompile=True) syntax
|
||||
@overload
|
||||
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
|
||||
allow_implicit:bool=True, grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
|
||||
allow_implicit:bool=False, grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
|
||||
@overload
|
||||
def function(fxn:None=None, *, precompile:bool=False, precompile_backward:bool=False,
|
||||
allow_implicit:bool=True, grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
|
||||
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
|
||||
allow_implicit:bool=False, grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
|
||||
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False,
|
||||
allow_implicit:bool=False, grad_fxn:Callable|None=None):
|
||||
if fxn is None:
|
||||
return lambda f: _function(f, precompile=precompile, precompile_backward=precompile_backward,
|
||||
allow_implicit=allow_implicit, grad_fxn=grad_fxn)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue