Compare commits

...

8 commits

Author SHA1 Message Date
George Hotz
34d2ef6447
Merge branch 'master' into precompile_backward 2026-03-17 15:17:41 +08:00
George Hotz
125d987886 no NOOPT 2026-03-17 15:14:25 +08:00
George Hotz
c83f2a7154 simpler 2026-03-17 15:01:46 +08:00
George Hotz
32a8b3aaa0 split v not split 2026-03-17 14:57:53 +08:00
George Hotz
39aae38b2a compact grad 2026-03-17 14:37:40 +08:00
George Hotz
f909d5c983 fix 2026-03-17 14:19:47 +08:00
George Hotz
f7512e9595 cleanups 2026-03-17 14:14:41 +08:00
George Hotz
9f591b42d1 add precompile backward support 2026-03-17 11:40:09 +08:00
5 changed files with 80 additions and 31 deletions

View file

@ -1,7 +1,7 @@
import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor
from tinygrad import Tensor, GlobalCounters
from tinygrad.uop.ops import UOp
class TestFunction(unittest.TestCase):
@ -384,5 +384,31 @@ class TestFunctionTuple(unittest.TestCase):
np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])
class TestFunctionGrad(unittest.TestCase):
def test_function_grad_ops(self, precompile=False, precompile_backward=False):
N = 64
x = Tensor.ones(N,N).contiguous()
w1 = Tensor.ones(N,N, requires_grad=True).contiguous()
w2 = Tensor.ones(N,N, requires_grad=True).contiguous()
w3 = Tensor.ones(N,N, requires_grad=True).contiguous()
ref = Tensor.ones(N,N).contiguous()
Tensor.realize(x, w1, w2, w3, ref)
@function(precompile=precompile, precompile_backward=precompile_backward)
def f(x, w1, w2, w3) -> tuple[Tensor, ...]:
p1 = x@w1
p2 = p1@w2
p3 = p2@w3
return p1, p2, p3, p3.contiguous()
ret = f(x, w1, w2, w3)[-1]
loss = (ret-ref).square().mean().backward()
print("RESET")
GlobalCounters.reset()
loss.realize(w1.grad, w2.grad, w3.grad)
print(GlobalCounters.global_ops, GlobalCounters.global_mem)
self.assertLessEqual(GlobalCounters.global_ops, 4739073)
def test_function_grad_ops_precompile(self): self.test_function_grad_ops(precompile=True)
def test_function_grad_ops_precompile_backward(self):
self.test_function_grad_ops(precompile=True, precompile_backward=True)
if __name__ == '__main__':
unittest.main()

View file

@ -117,6 +117,9 @@ pm_early_transform_tensor_graph = PatternMatcher([
# transform precompiled CALLs
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
# resolve TUPLE+GETTUPLE (for precompiled calls)
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view),

View file

@ -20,9 +20,11 @@ pm_ctx = PatternMatcher([
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=True, grad_fxn:Callable|None=None):
self.fxn = fxn
self.precompile = precompile
self.precompile_backward = precompile_backward
self.allow_implicit = allow_implicit
self.grad_fxn = grad_fxn
@ -70,7 +72,8 @@ class _function(Generic[ReturnType]):
#call = assigned.call(*call_uops, buffer, name=name)
#ret = buffer.after(call)
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile)
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
precompile_backward=self.precompile_backward)
if isinstance(ret, tuple):
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
else:
@ -78,11 +81,14 @@ class _function(Generic[ReturnType]):
# overload signatures support both @function and @function(precompile=True) syntax
@overload
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True,
grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=True, grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
@overload
def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True,
grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn)
return _function(fxn, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn)
def function(fxn:None=None, *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=True, grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
if fxn is None:
return lambda f: _function(f, precompile=precompile, precompile_backward=precompile_backward,
allow_implicit=allow_implicit, grad_fxn=grad_fxn)
return _function(fxn, precompile=precompile, precompile_backward=precompile_backward,
allow_implicit=allow_implicit, grad_fxn=grad_fxn)

View file

@ -13,25 +13,33 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
"""Remove unused PARAMs from body and return compacted (body, args)."""
used = sorted({p.arg: p for p in body.toposort() if p.op is Ops.PARAM}.items())
return body.substitute({p: p.replace(arg=j) for j,(_, p) in enumerate(used)}, walk=True), tuple(all_args[i] for i,_ in used)
def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
fxn, args = k.src[0], k.src[1:]
if k.arg.grad_fxn is not None:
return (None,) + (k.arg.grad_fxn(*ctx.src, call=k) if ctx.op is Ops.TUPLE else k.arg.grad_fxn(ctx, k))
if ctx.op is Ops.TUPLE:
real = [g for g in ctx.src if g.op is not Ops.NOOP]
return (None,) + (k.arg.grad_fxn(*real, call=k) if len(real) > 1 else k.arg.grad_fxn(real[0], k))
return (None,) + k.arg.grad_fxn(ctx, k)
assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grad_args = ctx.src
root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args)))
root_grad = UOp(Ops.TUPLE, src=tuple(fxn.src[i].const_like(0) if g.op is Ops.NOOP else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
grads = compute_gradient(fxn, root_grad, set(params.values()))
ret: list[UOp|None] = [None]
for i in range(len(args)):
if (p:=params.get(i, None)) is not None and p in grads:
# TODO: compact the args and remove unused ones
assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
bwd_call = grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward)
ret.append(bwd_call.gettuple(0))
else:
ret.append(None)
return tuple(ret)
# for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
fwd_outs = tuple(k.gettuple(i) for i in range(len(fxn.src))) if k.arg.precompile else ()
# collect needed gradient bodies, compact unused params, create a single backward CALL
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))
# ctx is grad_output
pm_gradient = PatternMatcher([
@ -63,32 +71,37 @@ pm_gradient = PatternMatcher([
(UPat(Ops.TUPLE), lambda ctx: ctx.src),
# NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
# gradient on CALL: use provided grad_fxn or auto-differentiate
(UPat(Ops.CALL, name="k"), call_gradient),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda: (None,)),
])
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
def _deepwalk(root:UOp, targets:set[UOp]) -> tuple[list[UOp], dict[UOp, bool]]:
# compute the target path (top down)
in_target_path: dict[UOp, bool] = {}
for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
# don't flow through DETACH or anything not in target path
return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node]))
return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node])), in_target_path
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
walk, in_target_path = _deepwalk(root, targets)
grads: dict[UOp, UOp] = {root: root_grad}
for t0 in reversed(_deepwalk(root, targets)):
for t0 in reversed(walk):
if t0 not in grads: continue
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
if t0.op is Ops.GETTUPLE:
k = t0.src[0] # the CALL
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
n_outputs = len(k.src[0].src)
prev: tuple[UOp, ...] = grads[k].src if k in grads else tuple(grads[t0].const_like(0) for _ in range(n_outputs))
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
# CALL: pass needed param set so backward only computes required gradients
if t0.op is Ops.CALL:
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
else:
lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
for k,v in zip(t0.src, lgrads):

View file

@ -697,6 +697,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def has_buffer_identity(self):
"""Check if this UOp has a concrete buffer identity in the graph (RESHAPE/MULTI -> BUFFER chain)."""
if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
if self.op is Ops.GETTUPLE and self.src[0].op is Ops.TUPLE: return self.src[0].src[self.arg].has_buffer_identity()
return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}
def _base_buffer_is_realized(self) -> bool: