mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
grad_outer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0c8d04feb |
2 changed files with 46 additions and 0 deletions
|
|
@ -57,6 +57,41 @@ class TestOuterRange(unittest.TestCase):
|
|||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
|
||||
def test_range_grad(self):
  """Gradients through an OUTER-range matmul loop must match the unrolled reference.

  The reference chains three matmuls explicitly, `((vec @ mats[0]) @ mats[1]) @ mats[2]`,
  and backpropagates a squared-error loss. The same loss is then built on top of
  `range_matmul`, which expresses the chain as a loop over an OUTER range, and the
  resulting `vec.grad` / `mats.grad` are compared against the reference gradients.
  """
  # Local import so this method does not depend on what the test module imports at file level.
  import numpy as np

  def range_matmul(vec, mats):
    # This test calls it with vec: (1, 3) and mats: (3, 3, 3); the range below is hard-coded to 3 steps.
    # vec and mats are expected to already have requires_grad set by the caller.
    i = UOp.range(3, -100, AxisType.OUTER)  # loop axis
    vec_i = Tensor(vec.uop.after(i))  # "loop-carried" vector
    vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
    body = (vec_i.contiguous() @ mats[vi])  # matmul using loop index
    out = Tensor(vec.uop.after(vec_i.uop.store(body.uop).end(i)))
    return out

  vec = Tensor.randn(1, 3, requires_grad=True)
  mats = Tensor.randn(3, 3, 3, requires_grad=True)
  Tensor.realize(vec, mats)

  # Reference gradients from the explicitly unrolled chain of matmuls.
  ref = ((vec @ mats[0]) @ mats[1]) @ mats[2]
  loss = (1.0 - ref).square().mean()
  loss.backward()
  Tensor.realize(vec.grad, mats.grad)
  # Snapshot as independent numpy arrays before the .grad attributes are reset below.
  ref_vec_grad = vec.grad.numpy().copy()
  ref_mats_grad = mats.grad.numpy().copy()
  vec.grad = None
  mats.grad = None

  # Same loss, but with the forward pass expressed through the OUTER range.
  out = range_matmul(vec, mats)
  loss = (1.0 - out).square().mean()
  loss.backward()
  assert vec.grad is not None, "no gradient reached vec through the range"
  assert mats.grad is not None, "no gradient reached mats through the range"
  Tensor.realize(vec.grad, mats.grad)

  vec_grad = vec.grad.numpy()
  mats_grad = mats.grad.numpy()
  assert np.isfinite(vec_grad).all(), f"{vec_grad=}"
  assert np.isfinite(mats_grad).all(), f"{mats_grad=}"
  # On mismatch assert_allclose reports both arrays, which replaces the manual prints.
  np.testing.assert_allclose(vec_grad, ref_vec_grad, rtol=1e-4, atol=1e-5)
  np.testing.assert_allclose(mats_grad, ref_mats_grad, rtol=1e-4, atol=1e-5)
|
||||
|
||||
class TestOuterworld(unittest.TestCase):
|
||||
def test_range_plus_1(self):
|
||||
t = Tensor.arange(100).reshape(10,10).realize()
|
||||
|
|
|
|||
|
|
@ -43,6 +43,17 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
# there's no gradient for bitcast
|
||||
(UPat(Ops.BITCAST), lambda: (None,)),
|
||||
|
||||
# RANGE: loop index / axis, not a differentiable quantity
|
||||
(UPat(Ops.RANGE), lambda: (None,)),
|
||||
|
||||
# STORE: buffer write. Gradient flows only into the value being stored.
|
||||
# src layout is roughly (buffer, value, *axes_or_indices)
|
||||
(UPat(Ops.STORE), lambda ctx: (None, ctx)),
|
||||
|
||||
# END: loop terminator / "end of range" node.
|
||||
# Just pass the gradient into the body (first src), ignore the ranges.
|
||||
(UPat(Ops.END, name="ret"), lambda ctx, ret: (ctx, *[None]*(len(ret.src) - 1))),
|
||||
])
|
||||
|
||||
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue