Compare commits

...

8 commits

Author SHA1 Message Date
George Hotz
34d2ef6447
Merge branch 'master' into precompile_backward 2026-03-17 15:17:41 +08:00
George Hotz
125d987886 no NOOPT 2026-03-17 15:14:25 +08:00
George Hotz
c83f2a7154 simpler 2026-03-17 15:01:46 +08:00
George Hotz
32a8b3aaa0 split v not split 2026-03-17 14:57:53 +08:00
George Hotz
39aae38b2a compact grad 2026-03-17 14:37:40 +08:00
George Hotz
f909d5c983 fix 2026-03-17 14:19:47 +08:00
George Hotz
f7512e9595 cleanups 2026-03-17 14:14:41 +08:00
George Hotz
9f591b42d1 add precompile backward support 2026-03-17 11:40:09 +08:00
5 changed files with 80 additions and 31 deletions

View file

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import unittest import unittest
from tinygrad.function import function from tinygrad.function import function
from tinygrad import Tensor from tinygrad import Tensor, GlobalCounters
from tinygrad.uop.ops import UOp from tinygrad.uop.ops import UOp
class TestFunction(unittest.TestCase): class TestFunction(unittest.TestCase):
@ -384,5 +384,31 @@ class TestFunctionTuple(unittest.TestCase):
np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.]) np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.]) np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])
class TestFunctionGrad(unittest.TestCase):
def test_function_grad_ops(self, precompile=False, precompile_backward=False):
N = 64
x = Tensor.ones(N,N).contiguous()
w1 = Tensor.ones(N,N, requires_grad=True).contiguous()
w2 = Tensor.ones(N,N, requires_grad=True).contiguous()
w3 = Tensor.ones(N,N, requires_grad=True).contiguous()
ref = Tensor.ones(N,N).contiguous()
Tensor.realize(x, w1, w2, w3, ref)
@function(precompile=precompile, precompile_backward=precompile_backward)
def f(x, w1, w2, w3) -> tuple[Tensor, ...]:
p1 = x@w1
p2 = p1@w2
p3 = p2@w3
return p1, p2, p3, p3.contiguous()
ret = f(x, w1, w2, w3)[-1]
loss = (ret-ref).square().mean().backward()
print("RESET")
GlobalCounters.reset()
loss.realize(w1.grad, w2.grad, w3.grad)
print(GlobalCounters.global_ops, GlobalCounters.global_mem)
self.assertLessEqual(GlobalCounters.global_ops, 4739073)
def test_function_grad_ops_precompile(self): self.test_function_grad_ops(precompile=True)
def test_function_grad_ops_precompile_backward(self):
self.test_function_grad_ops(precompile=True, precompile_backward=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -117,6 +117,9 @@ pm_early_transform_tensor_graph = PatternMatcher([
# transform precompiled CALLs # transform precompiled CALLs
(UPat(Ops.CALL, name="c"), transform_precompiled_call), (UPat(Ops.CALL, name="c"), transform_precompiled_call),
# resolve TUPLE+GETTUPLE (for precompiled calls)
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range # CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view), (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view),

View file

@ -20,9 +20,11 @@ pm_ctx = PatternMatcher([
ReturnType = TypeVar('ReturnType') ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]): class _function(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None): def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
allow_implicit:bool=True, grad_fxn:Callable|None=None):
self.fxn = fxn self.fxn = fxn
self.precompile = precompile self.precompile = precompile
self.precompile_backward = precompile_backward
self.allow_implicit = allow_implicit self.allow_implicit = allow_implicit
self.grad_fxn = grad_fxn self.grad_fxn = grad_fxn
@ -70,7 +72,8 @@ class _function(Generic[ReturnType]):
#call = assigned.call(*call_uops, buffer, name=name) #call = assigned.call(*call_uops, buffer, name=name)
#ret = buffer.after(call) #ret = buffer.after(call)
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile) fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
precompile_backward=self.precompile_backward)
if isinstance(ret, tuple): if isinstance(ret, tuple):
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret)))) return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
else: else:
@ -78,11 +81,14 @@ class _function(Generic[ReturnType]):
# overload signatures support both @function and @function(precompile=True) syntax # overload signatures support both @function and @function(precompile=True) syntax
@overload @overload
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True, def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
grad_fxn:Callable|None=None) -> _function[ReturnType]: ... allow_implicit:bool=True, grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
@overload @overload
def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True, def function(fxn:None=None, *, precompile:bool=False, precompile_backward:bool=False,
grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ... allow_implicit:bool=True, grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None): def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn) if fxn is None:
return _function(fxn, precompile=precompile, allow_implicit=allow_implicit, grad_fxn=grad_fxn) return lambda f: _function(f, precompile=precompile, precompile_backward=precompile_backward,
allow_implicit=allow_implicit, grad_fxn=grad_fxn)
return _function(fxn, precompile=precompile, precompile_backward=precompile_backward,
allow_implicit=allow_implicit, grad_fxn=grad_fxn)

View file

@ -13,25 +13,33 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),) return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],) if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]: def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
"""Remove unused PARAMs from body and return compacted (body, args)."""
used = sorted({p.arg: p for p in body.toposort() if p.op is Ops.PARAM}.items())
return body.substitute({p: p.replace(arg=j) for j,(_, p) in enumerate(used)}, walk=True), tuple(all_args[i] for i,_ in used)
def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
fxn, args = k.src[0], k.src[1:] fxn, args = k.src[0], k.src[1:]
if k.arg.grad_fxn is not None: if k.arg.grad_fxn is not None:
return (None,) + (k.arg.grad_fxn(*ctx.src, call=k) if ctx.op is Ops.TUPLE else k.arg.grad_fxn(ctx, k)) if ctx.op is Ops.TUPLE:
real = [g for g in ctx.src if g.op is not Ops.NOOP]
return (None,) + (k.arg.grad_fxn(*real, call=k) if len(real) > 1 else k.arg.grad_fxn(real[0], k))
return (None,) + k.arg.grad_fxn(ctx, k)
assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}" assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM} params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grad_args = ctx.src grad_args = ctx.src
root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args))) root_grad = UOp(Ops.TUPLE, src=tuple(fxn.src[i].const_like(0) if g.op is Ops.NOOP else g.param_like(len(args)+i) for i,g in enumerate(grad_args)))
grads = compute_gradient(fxn, root_grad, set(params.values())) grads = compute_gradient(fxn, root_grad, set(params.values()))
ret: list[UOp|None] = [None] # for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
for i in range(len(args)): fwd_subs = {src: src.param_like(len(args)+len(grad_args)+i) for i, src in enumerate(fxn.src)} if k.arg.precompile else {}
if (p:=params.get(i, None)) is not None and p in grads: fwd_outs = tuple(k.gettuple(i) for i in range(len(fxn.src))) if k.arg.precompile else ()
# TODO: compact the args and remove unused ones # collect needed gradient bodies, compact unused params, create a single backward CALL
assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad" grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
bwd_call = grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward) bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
ret.append(bwd_call.gettuple(0)) bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
else: bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
ret.append(None) gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
return tuple(ret) return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))
# ctx is grad_output # ctx is grad_output
pm_gradient = PatternMatcher([ pm_gradient = PatternMatcher([
@ -63,32 +71,37 @@ pm_gradient = PatternMatcher([
(UPat(Ops.TUPLE), lambda ctx: ctx.src), (UPat(Ops.TUPLE), lambda ctx: ctx.src),
# NOTE: this is only correct when the KERNEL has a single output # NOTE: this is only correct when the KERNEL has a single output
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)), (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
# gradient on CALL: use provided grad_fxn or auto-differentiate
(UPat(Ops.CALL, name="k"), call_gradient),
# there's no gradient for bitcast # there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda: (None,)), (UPat(Ops.BITCAST), lambda: (None,)),
]) ])
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]: def _deepwalk(root:UOp, targets:set[UOp]) -> tuple[list[UOp], dict[UOp, bool]]:
# compute the target path (top down) # compute the target path (top down)
in_target_path: dict[UOp, bool] = {} in_target_path: dict[UOp, bool] = {}
for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src) for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
# don't flow through DETACH or anything not in target path # don't flow through DETACH or anything not in target path
return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node])) return list(root.toposort(lambda node: node.op is not Ops.DETACH and in_target_path[node])), in_target_path
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]: def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
walk, in_target_path = _deepwalk(root, targets)
grads: dict[UOp, UOp] = {root: root_grad} grads: dict[UOp, UOp] = {root: root_grad}
for t0 in reversed(_deepwalk(root, targets)): for t0 in reversed(walk):
if t0 not in grads: continue if t0 not in grads: continue
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL # GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
if t0.op is Ops.GETTUPLE: if t0.op is Ops.GETTUPLE:
k = t0.src[0] # the CALL k = t0.src[0] # the CALL
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
n_outputs = len(k.src[0].src) n_outputs = len(k.src[0].src)
prev: tuple[UOp, ...] = grads[k].src if k in grads else tuple(grads[t0].const_like(0) for _ in range(n_outputs)) prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs))) grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
continue continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0])) # CALL: pass needed param set so backward only computes required gradients
if t0.op is Ops.CALL:
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
else:
lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...") if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}" assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
for k,v in zip(t0.src, lgrads): for k,v in zip(t0.src, lgrads):

View file

@ -697,6 +697,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def has_buffer_identity(self): def has_buffer_identity(self):
"""Check if this UOp has a concrete buffer identity in the graph (RESHAPE/MULTI -> BUFFER chain).""" """Check if this UOp has a concrete buffer identity in the graph (RESHAPE/MULTI -> BUFFER chain)."""
if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity() if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
if self.op is Ops.GETTUPLE and self.src[0].op is Ops.TUPLE: return self.src[0].src[self.arg].has_buffer_identity()
return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM} return self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}
def _base_buffer_is_realized(self) -> bool: def _base_buffer_is_realized(self) -> bool: