Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
a0c8d04feb grad outerworld test 2025-11-16 07:43:51 -08:00
2 changed files with 46 additions and 0 deletions

View file

@ -57,6 +57,41 @@ class TestOuterRange(unittest.TestCase):
# TODO: testing allclose
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
def test_range_grad(self):
def range_matmul(vec, mats):
# vec: (1, 10), mats: (3, 10, 10)
# assume vec, mats already have requires_grad set however you like
i = UOp.range(3, -100, AxisType.OUTER) # loop axis
vec_i = Tensor(vec.uop.after(i)) # "loop-carried" vector
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
body = (vec_i.contiguous() @ mats[vi]) # matmul using loop index
out = Tensor(vec.uop.after(vec_i.uop.store(body.uop).end(i)))
return out
vec = Tensor.randn(1, 3, requires_grad=True)
mats = Tensor.randn(3, 3, 3, requires_grad=True)
Tensor.realize(vec, mats)
ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
loss = (1.0 - ref).square().mean()
loss.backward()
Tensor.realize(vec.grad, mats.grad)
print(vec.grad.numpy())
print(mats.grad.numpy())
vec.grad = None
mats.grad = None
out = range_matmul(vec, mats)
loss = (1.0 - out).square().mean()
loss.backward()
Tensor.realize(vec.grad, mats.grad)
print(vec.grad, mats.grad) # should be non-None and finite
print(vec.grad.numpy())
print(mats.grad.numpy())
class TestOuterworld(unittest.TestCase):
def test_range_plus_1(self):
t = Tensor.arange(100).reshape(10,10).realize()

View file

@ -43,6 +43,17 @@ pm_gradient = PatternMatcher([
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda: (None,)),
# RANGE: loop index / axis, not a differentiable quantity
(UPat(Ops.RANGE), lambda: (None,)),
# STORE: buffer write. Gradient flows only into the value being stored.
# src layout is roughly (buffer, value, *axes_or_indices)
(UPat(Ops.STORE), lambda ctx: (None, ctx)),
# END: loop terminator / "end of range" node.
# Just pass the gradient into the body (first src), ignore the ranges.
(UPat(Ops.END, name="ret"), lambda ctx, ret: (ctx, *[None]*(len(ret.src) - 1))),
])
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]: