test backward in test_tiny (#11697)

* test backward in test_tiny

* empty
This commit is contained in:
George Hotz 2025-08-16 20:29:39 -07:00 committed by GitHub
commit 9366a23eb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 20 additions and 3 deletions

View file

@ -29,8 +29,7 @@ if __name__ == "__main__":
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()
return loss
return loss.realize(*opt.schedule_step())
@TinyJit
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100

View file

@ -106,6 +106,24 @@ class TestTiny(unittest.TestCase):
probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
self.assertEqual(len(probs[0]), 10)
# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
@unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images")
def test_mnist_backward(self):
# NOTE: we don't have the whole model here for speed
layers = [
nn.Conv2d(1, 32, 5), Tensor.relu,
nn.Conv2d(32, 32, 5), Tensor.relu]
# replace random weights with ones
# TODO: there's a bug here where it's tying two of the biases together. we need UNIQUE const
#Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
for p in nn.state.get_parameters(layers): p.replace(Tensor.empty(p.shape))
# realize gradients
for x in nn.state.get_parameters(layers): x.requires_grad_()
Tensor.empty(4, 1, 28, 28).sequential(layers).sum().backward()
Tensor.realize(*[x.grad for x in nn.state.get_parameters(layers) if x.grad is not None])
# *** image ***
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")

View file

@ -252,7 +252,7 @@ class Tensor(MathTrait):
# create the schedule
schedule, var_vals = create_schedule_with_vars(sink)
schedule = memory_planner(schedule)
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
if DEBUG >= 1 and len(schedule) > 1: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
return schedule, var_vals
def schedule(self, *lst:Tensor) -> list[ScheduleItem]: