mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
8 commits
master
...
precompile
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34d2ef6447 |
||
|
|
125d987886 | ||
|
|
c83f2a7154 | ||
|
|
32a8b3aaa0 | ||
|
|
39aae38b2a | ||
|
|
f909d5c983 | ||
|
|
f7512e9595 | ||
|
|
9f591b42d1 |
5 changed files with 80 additions and 31 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad.function import function
|
from tinygrad.function import function
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor, GlobalCounters
|
||||||
from tinygrad.uop.ops import UOp
|
from tinygrad.uop.ops import UOp
|
||||||
|
|
||||||
class TestFunction(unittest.TestCase):
|
class TestFunction(unittest.TestCase):
|
||||||
|
|
@ -384,5 +384,31 @@ class TestFunctionTuple(unittest.TestCase):
|
||||||
np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
|
np.testing.assert_allclose(x.grad.numpy(), [1., 1., 1.])
|
||||||
np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])
|
np.testing.assert_allclose(y.grad.numpy(), [1., 1., 1.])
|
||||||
|
|
||||||
|
class TestFunctionGrad(unittest.TestCase):
  """Op-count checks for the backward pass of a multi-output @function."""
  def test_function_grad_ops(self, precompile=False, precompile_backward=False):
    """Backprop through a three-matmul @function and bound the ops counted while realizing loss and grads."""
    size, max_ops = 64, 4739073
    x = Tensor.ones(size, size).contiguous()
    weights = [Tensor.ones(size, size, requires_grad=True).contiguous() for _ in range(3)]
    ref = Tensor.ones(size, size).contiguous()
    # realize the inputs up front, before the counters are reset below
    Tensor.realize(x, *weights, ref)

    @function(precompile=precompile, precompile_backward=precompile_backward)
    def f(x, w1, w2, w3) -> tuple[Tensor, ...]:
      # every intermediate product is returned as an output, plus a contiguous copy of the last one
      outs = []
      for w in (w1, w2, w3): outs.append((outs[-1] if outs else x) @ w)
      return (*outs, outs[-1].contiguous())

    # only the last output feeds the loss; the other three outputs receive no gradient
    ret = f(x, *weights)[-1]
    loss = (ret-ref).square().mean().backward()
    print("RESET")
    GlobalCounters.reset()
    loss.realize(*(w.grad for w in weights))
    print(GlobalCounters.global_ops, GlobalCounters.global_mem)
    self.assertLessEqual(GlobalCounters.global_ops, max_ops)

  def test_function_grad_ops_precompile(self):
    """Same check with the forward call precompiled."""
    self.test_function_grad_ops(precompile=True)

  def test_function_grad_ops_precompile_backward(self):
    """Same check with both the forward and the backward call precompiled."""
    self.test_function_grad_ops(precompile=True, precompile_backward=True)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,9 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
||||||
# transform precompiled CALLs
|
# transform precompiled CALLs
|
||||||
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
|
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
|
||||||
|
|
||||||
|
# resolve TUPLE+GETTUPLE (for precompiled calls)
|
||||||
|
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||||
|
|
||||||
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
|
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
|
||||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view),
|
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view),
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,11 @@ pm_ctx = PatternMatcher([
|
||||||
|
|
||||||
ReturnType = TypeVar('ReturnType')
|
ReturnType = TypeVar('ReturnType')
|
||||||
class _function(Generic[ReturnType]):
|
class _function(Generic[ReturnType]):
|
||||||
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
             allow_implicit:bool=True, grad_fxn:Callable|None=None):
  """Remember the wrapped callable and the keyword-only options it was decorated with."""
  self.fxn, self.grad_fxn = fxn, grad_fxn
  self.precompile, self.precompile_backward = precompile, precompile_backward
  self.allow_implicit = allow_implicit
|
||||||
|
|
||||||
|
|
@ -70,7 +72,8 @@ class _function(Generic[ReturnType]):
|
||||||
#call = assigned.call(*call_uops, buffer, name=name)
|
#call = assigned.call(*call_uops, buffer, name=name)
|
||||||
#ret = buffer.after(call)
|
#ret = buffer.after(call)
|
||||||
|
|
||||||
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile)
|
fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
|
||||||
|
precompile_backward=self.precompile_backward)
|
||||||
if isinstance(ret, tuple):
|
if isinstance(ret, tuple):
|
||||||
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
|
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
|
||||||
else:
|
else:
|
||||||
|
|
@ -78,11 +81,14 @@ class _function(Generic[ReturnType]):
|
||||||
|
|
||||||
# overload signatures support both @function and @function(precompile=True) syntax
|
# overload signatures support both @function and @function(precompile=True) syntax
|
||||||
@overload
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False,
             allow_implicit:bool=True, grad_fxn:Callable|None=None) -> _function[ReturnType]: ...
@overload
def function(fxn:None=None, *, precompile:bool=False, precompile_backward:bool=False,
             allow_implicit:bool=True, grad_fxn:Callable|None=None) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False, allow_implicit:bool=True, grad_fxn:Callable|None=None):
  """Wrap a callable in `_function`, usable as `@function` or as `@function(precompile=True, ...)`.

  Called with a callable, the wrapper is returned directly. Called with options only (`fxn` is None),
  a decorator is returned that builds the wrapper with those options.
  """
  def wrap(f):
    return _function(f, precompile=precompile, precompile_backward=precompile_backward,
                     allow_implicit=allow_implicit, grad_fxn=grad_fxn)
  return wrap if fxn is None else wrap(fxn)
|
||||||
|
|
|
||||||
|
|
@ -13,25 +13,33 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
||||||
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
|
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
|
||||||
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
||||||
|
|
||||||
def _compact_params(body:UOp, all_args:tuple[UOp, ...]) -> tuple[UOp, tuple[UOp, ...]]:
  """Drop PARAMs that `body` never references and renumber the remaining ones densely from 0.

  Returns the rewritten body and the entries of `all_args` whose slots are still referenced,
  in ascending slot order so they line up with the renumbered PARAMs.
  """
  by_slot: dict[int, UOp] = {}
  for u in body.toposort():
    if u.op is Ops.PARAM: by_slot[u.arg] = u
  slots = sorted(by_slot)
  renumber = {by_slot[old]: by_slot[old].replace(arg=new) for new, old in enumerate(slots)}
  return body.substitute(renumber, walk=True), tuple(all_args[old] for old in slots)

def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
  """Gradient of the CALL `k` given the upstream gradient `ctx`.

  `needed` holds the indices (into the call arguments `k.src[1:]`) that require a gradient.
  The result has one entry per source of `k`: a leading None for the function body `k.src[0]`,
  then a gradient UOp or None for each call argument.
  """
  fxn, args = k.src[0], k.src[1:]
  custom_grad = k.arg.grad_fxn
  if custom_grad is not None:
    if ctx.op is not Ops.TUPLE: return (None,) + custom_grad(ctx, k)
    # NOOP entries stand for outputs that received no gradient; only the real ones are forwarded
    real = [g for g in ctx.src if g.op is not Ops.NOOP]
    if len(real) > 1: return (None,) + custom_grad(*real, call=k)
    return (None,) + custom_grad(real[0], k)
  assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
  params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
  grad_args = ctx.src
  n_args, n_grads = len(args), len(grad_args)
  # incoming gradients become PARAMs numbered after the call args; outputs with no gradient are seeded with zeros
  seeds = tuple(fxn.src[i].const_like(0) if g.op is Ops.NOOP else g.param_like(n_args+i) for i,g in enumerate(grad_args))
  grads = compute_gradient(fxn, UOp(Ops.TUPLE, src=seeds), set(params.values()))
  # for precompiled calls, substitute forward outputs with params so intermediates aren't recomputed
  fwd_subs: dict[UOp, UOp] = {}
  fwd_outs: tuple[UOp, ...] = ()
  if k.arg.precompile:
    fwd_subs = {src: src.param_like(n_args+n_grads+i) for i, src in enumerate(fxn.src)}
    fwd_outs = tuple(k.gettuple(i) for i in range(len(fxn.src)))
  # collect needed gradient bodies, compact unused params, create a single backward CALL
  grad_bodies: list[tuple[int, UOp]] = []
  for i in needed:
    p = params.get(i)
    if p is not None and p in grads: grad_bodies.append((i, grads[p]))
  bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
  bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
  bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
  # map each call-argument index to its position in the backward CALL's output tuple
  out_slot = {arg_i: pos for pos, (arg_i, _) in enumerate(grad_bodies)}
  arg_grads: list[UOp|None] = [bwd_call.gettuple(out_slot[i]) if i in out_slot else None for i in range(n_args)]
  return (None,) + tuple(arg_grads)
|
||||||
|
|
||||||
# ctx is grad_output
|
# ctx is grad_output
|
||||||
pm_gradient = PatternMatcher([
|
pm_gradient = PatternMatcher([
|
||||||
|
|
@ -63,32 +71,37 @@ pm_gradient = PatternMatcher([
|
||||||
(UPat(Ops.TUPLE), lambda ctx: ctx.src),
|
(UPat(Ops.TUPLE), lambda ctx: ctx.src),
|
||||||
# NOTE: this is only correct when the KERNEL has a single output
|
# NOTE: this is only correct when the KERNEL has a single output
|
||||||
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
|
(UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
|
||||||
# gradient on CALL: use provided grad_fxn or auto-differentiate
|
|
||||||
(UPat(Ops.CALL, name="k"), call_gradient),
|
|
||||||
# there's no gradient for bitcast
|
# there's no gradient for bitcast
|
||||||
(UPat(Ops.BITCAST), lambda: (None,)),
|
(UPat(Ops.BITCAST), lambda: (None,)),
|
||||||
])
|
])
|
||||||
|
|
||||||
def _deepwalk(root:UOp, targets:set[UOp]) -> tuple[list[UOp], dict[UOp, bool]]:
  """Return the nodes under `root` that lead to a target, plus the per-node reachability map.

  The map records, for every node in `root.toposort()`, whether any of its sources is a target
  or is itself on a path to one. The returned list is a toposort gated on that map.
  """
  # compute the target path (top down)
  in_target_path: dict[UOp, bool] = {}
  for u in root.toposort():
    in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
  # don't flow through DETACH or anything not in target path
  def on_path(node:UOp) -> bool: return node.op is not Ops.DETACH and in_target_path[node]
  return list(root.toposort(on_path)), in_target_path
|
||||||
|
|
||||||
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
|
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
|
||||||
|
walk, in_target_path = _deepwalk(root, targets)
|
||||||
grads: dict[UOp, UOp] = {root: root_grad}
|
grads: dict[UOp, UOp] = {root: root_grad}
|
||||||
for t0 in reversed(_deepwalk(root, targets)):
|
for t0 in reversed(walk):
|
||||||
if t0 not in grads: continue
|
if t0 not in grads: continue
|
||||||
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
|
# GETTUPLE: accumulate gradient into a TUPLE UOp on the CALL, process when we hit the CALL
|
||||||
if t0.op is Ops.GETTUPLE:
|
if t0.op is Ops.GETTUPLE:
|
||||||
k = t0.src[0] # the CALL
|
k = t0.src[0] # the CALL
|
||||||
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
|
assert k.op is Ops.CALL and k.src[0].op is Ops.TUPLE
|
||||||
n_outputs = len(k.src[0].src)
|
n_outputs = len(k.src[0].src)
|
||||||
prev: tuple[UOp, ...] = grads[k].src if k in grads else tuple(grads[t0].const_like(0) for _ in range(n_outputs))
|
prev = grads[k].src if k in grads else tuple(UOp(Ops.NOOP) for _ in range(n_outputs))
|
||||||
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
|
grads[k] = UOp.maketuple(*(prev[i] + grads[t0] if i == t0.arg and prev[i].op is not Ops.NOOP else
|
||||||
|
grads[t0] if i == t0.arg else prev[i] for i in range(n_outputs)))
|
||||||
continue
|
continue
|
||||||
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
# CALL: pass needed param set so backward only computes required gradients
|
||||||
|
if t0.op is Ops.CALL:
|
||||||
|
needed = {i for i, arg in enumerate(t0.src[1:]) if arg in targets or in_target_path.get(arg, False)}
|
||||||
|
lgrads:tuple[UOp|None, ...]|None = call_gradient(grads[t0], t0, needed)
|
||||||
|
else:
|
||||||
|
lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
||||||
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
|
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
|
||||||
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
|
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
|
||||||
for k,v in zip(t0.src, lgrads):
|
for k,v in zip(t0.src, lgrads):
|
||||||
|
|
|
||||||
|
|
@ -697,6 +697,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
def has_buffer_identity(self):
  """Return True if this UOp resolves to a BUFFER, BUFFER_VIEW or PARAM.

  RESHAPE and MULTI are looked through via their first source; a GETTUPLE applied directly
  to a TUPLE is looked through via the tuple element it selects.
  """
  if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.PARAM}: return True
  if self.op in {Ops.RESHAPE, Ops.MULTI}: return self.src[0].has_buffer_identity()
  if self.op is Ops.GETTUPLE and self.src[0].op is Ops.TUPLE:
    return self.src[0].src[self.arg].has_buffer_identity()
  return False
|
||||||
|
|
||||||
def _base_buffer_is_realized(self) -> bool:
|
def _base_buffer_is_realized(self) -> bool:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue