mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
grad_outer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0c8d04feb |
2 changed files with 46 additions and 0 deletions
|
|
@ -57,6 +57,41 @@ class TestOuterRange(unittest.TestCase):
|
||||||
# TODO: testing allclose
|
# TODO: testing allclose
|
||||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||||
|
|
||||||
|
def test_range_grad(self):
|
||||||
|
def range_matmul(vec, mats):
|
||||||
|
# vec: (1, 10), mats: (3, 10, 10)
|
||||||
|
# assume vec, mats already have requires_grad set however you like
|
||||||
|
|
||||||
|
i = UOp.range(3, -100, AxisType.OUTER) # loop axis
|
||||||
|
vec_i = Tensor(vec.uop.after(i)) # "loop-carried" vector
|
||||||
|
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
|
||||||
|
|
||||||
|
body = (vec_i.contiguous() @ mats[vi]) # matmul using loop index
|
||||||
|
out = Tensor(vec.uop.after(vec_i.uop.store(body.uop).end(i)))
|
||||||
|
return out
|
||||||
|
|
||||||
|
vec = Tensor.randn(1, 3, requires_grad=True)
|
||||||
|
mats = Tensor.randn(3, 3, 3, requires_grad=True)
|
||||||
|
Tensor.realize(vec, mats)
|
||||||
|
|
||||||
|
ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
|
||||||
|
loss = (1.0 - ref).square().mean()
|
||||||
|
loss.backward()
|
||||||
|
Tensor.realize(vec.grad, mats.grad)
|
||||||
|
print(vec.grad.numpy())
|
||||||
|
print(mats.grad.numpy())
|
||||||
|
vec.grad = None
|
||||||
|
mats.grad = None
|
||||||
|
|
||||||
|
out = range_matmul(vec, mats)
|
||||||
|
loss = (1.0 - out).square().mean()
|
||||||
|
loss.backward()
|
||||||
|
Tensor.realize(vec.grad, mats.grad)
|
||||||
|
|
||||||
|
print(vec.grad, mats.grad) # should be non-None and finite
|
||||||
|
print(vec.grad.numpy())
|
||||||
|
print(mats.grad.numpy())
|
||||||
|
|
||||||
class TestOuterworld(unittest.TestCase):
|
class TestOuterworld(unittest.TestCase):
|
||||||
def test_range_plus_1(self):
|
def test_range_plus_1(self):
|
||||||
t = Tensor.arange(100).reshape(10,10).realize()
|
t = Tensor.arange(100).reshape(10,10).realize()
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,17 @@ pm_gradient = PatternMatcher([
|
||||||
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||||
# there's no gradient for bitcast
|
# there's no gradient for bitcast
|
||||||
(UPat(Ops.BITCAST), lambda: (None,)),
|
(UPat(Ops.BITCAST), lambda: (None,)),
|
||||||
|
|
||||||
|
# RANGE: loop index / axis, not a differentiable quantity
|
||||||
|
(UPat(Ops.RANGE), lambda: (None,)),
|
||||||
|
|
||||||
|
# STORE: buffer write. Gradient flows only into the value being stored.
|
||||||
|
# src layout is roughly (buffer, value, *axes_or_indices)
|
||||||
|
(UPat(Ops.STORE), lambda ctx: (None, ctx)),
|
||||||
|
|
||||||
|
# END: loop terminator / "end of range" node.
|
||||||
|
# Just pass the gradient into the body (first src), ignore the ranges.
|
||||||
|
(UPat(Ops.END, name="ret"), lambda ctx, ret: (ctx, *[None]*(len(ret.src) - 1))),
|
||||||
])
|
])
|
||||||
|
|
||||||
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
|
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue