Compare commits

...

6 commits

Author SHA1 Message Date
George Hotz
e4d8e03954 more test 2026-03-16 12:37:16 +08:00
George Hotz
09574a096a maketuple 2026-03-16 12:04:39 +08:00
George Hotz
b89c233917 fix 2026-03-16 11:56:36 +08:00
George Hotz
00d847afdf single backward gradient 2026-03-16 11:44:15 +08:00
George Hotz
8d75eed0a4 fix precompile 2026-03-16 11:33:52 +08:00
George Hotz
e59cbf78bd Add TUPLE and GETTUPLE 2026-03-16 11:22:03 +08:00
8 changed files with 281 additions and 26 deletions

View file

@ -3,6 +3,7 @@ import unittest
from tinygrad.function import function
from tinygrad import Tensor
from tinygrad.uop.ops import UOp
from tinygrad.helpers import GlobalCounters
class TestFunction(unittest.TestCase):
def test_simple(self):
@ -335,5 +336,100 @@ class TestFunctionMulti(unittest.TestCase):
f(x).sum().backward()
np.testing.assert_allclose(x.grad.numpy(), expected)
class TestFunctionTuple(unittest.TestCase):
def test_tuple(self, precompile=False):
x = Tensor.ones(3).contiguous()
@function(precompile=precompile)
def f(t:Tensor): return (t+1, t+2)
t1, t2 = f(x)
t1.realize(t2)
print(t1.tolist(), t2.tolist())
assert t1.tolist() == [2,2,2]
assert t2.tolist() == [3,3,3]
def test_tuple_precompile(self): self.test_tuple(True)
class TestFunctionBackward(unittest.TestCase):
def test_backward_has_grad(self):
N = 16
x = Tensor.empty(N, N)
w1 = Tensor.empty(N, N, requires_grad=True)
@function
def f(t:Tensor, w1:Tensor): return (t@w1)
f(x, w1).sum().backward()
assert w1.grad is not None
def test_double_matmul_backward(self):
N = 16
x = Tensor.empty(N, N)
w1 = Tensor.empty(N, N, requires_grad=True)
w2 = Tensor.empty(N, N, requires_grad=True)
ref = Tensor.empty(N)
@function(precompile=True)
def f(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
loss = (f(x, w1, w2)-ref).square().mean().backward()
loss.realize(w1.grad, w2.grad)
def test_backward_single_call(self):
N = 4
x = Tensor.arange(N*N).reshape(N, N).float()
w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_()
w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_()
xn, w1n, w2n = x.numpy(), w1.numpy(), w2.numpy()
@function
def f(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
f(x, w1, w2).sum().backward()
assert w1.grad is not None and w2.grad is not None
np.testing.assert_allclose(w1.grad.numpy(), xn.T @ np.ones((N,N)) @ w2n.T, atol=1e-3)
np.testing.assert_allclose(w2.grad.numpy(), (xn @ w1n).T @ np.ones((N,N)), atol=1e-3)
def test_backward_precompile_backward(self):
N = 4
x = Tensor.arange(N*N).reshape(N, N).float().contiguous()
w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
Tensor.realize(x, w1, w2)
xn, w1n, w2n = x.numpy(), w1.numpy(), w2.numpy()
@function(precompile=True, precompile_backward=True)
def f(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
loss = f(x, w1, w2).sum().backward()
assert w1.grad is not None and w2.grad is not None
GlobalCounters.reset()
Tensor.realize(loss, w1.grad, w2.grad)
np.testing.assert_allclose(w1.grad.numpy(), xn.T @ np.ones((N,N)) @ w2n.T, atol=1e-3)
np.testing.assert_allclose(w2.grad.numpy(), (xn @ w1n).T @ np.ones((N,N)), atol=1e-3)
def test_backward_precompile_backward_tuple(self):
N = 4
x = Tensor.arange(N*N).reshape(N, N).float().contiguous()
w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
Tensor.realize(x, w1, w2)
xn, w1n, w2n = x.numpy(), w1.numpy(), w2.numpy()
# non-tuple reference
ref_w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
ref_w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
Tensor.realize(ref_w1, ref_w2)
@function(precompile=True, precompile_backward=True)
def g(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
g(x, ref_w1, ref_w2).sum().backward()
GlobalCounters.reset()
Tensor.realize(ref_w1.grad, ref_w2.grad)
ref_ops = GlobalCounters.global_ops
# tuple version with intermediate — should not redo forward compute
@function(precompile=True, precompile_backward=True)
def f(t:Tensor, w1:Tensor, w2:Tensor):
h = t@w1
return (h, h@w2)
h, out = f(x, w1, w2)
loss = out.sum().backward()
assert w1.grad is not None and w2.grad is not None
GlobalCounters.reset()
Tensor.realize(loss, w1.grad, w2.grad)
np.testing.assert_allclose(w1.grad.numpy(), xn.T @ np.ones((N,N)) @ w2n.T, atol=1e-3)
np.testing.assert_allclose(w2.grad.numpy(), (xn @ w1n).T @ np.ones((N,N)), atol=1e-3)
# tuple version should have fewer ops than non-tuple (saves recomputing h=t@w1 in backward)
self.assertLessEqual(GlobalCounters.global_ops, ref_ops)
if __name__ == '__main__':
unittest.main()

View file

@ -94,8 +94,21 @@ def contiguous_mops_to_view(c:UOp):
def transform_precompiled_call(c:UOp) -> UOp|None:
if not c.arg.precompile: return None
if c.src[0].op is Ops.SINK: return None
out = _buffer_like(c)
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
# multi-output (TUPLE) precompiled calls: allocate a buffer per output and return a TUPLE of buffers
if c.src[0].op is Ops.TUPLE:
tuple_body = c.src[0]
out_bufs = []
sink_srcs = []
for i, elem in enumerate(tuple_body.src):
buf = UOp.new_buffer(c.device, prod(elem.max_shape), elem.dtype).reshape(elem.max_shape).shrink_to(elem.shape)
out_bufs.append(buf)
target = buf.param_like(len(c.src) - 1 + i).shrink_to(elem.shape)
sink_srcs.append(target.after(target.store(elem)))
fxn = UOp.sink(*sink_srcs)
new_call = c.replace(src=(fxn, *input_buffers, *out_bufs), dtype=dtypes.void, tag=None)
return UOp.maketuple(*[buf.after(new_call) for buf in out_bufs])
out = _buffer_like(c)
target = out.param_like(len(c.src)-1).shrink_to(c.shape)
fxn = target.after(target.store(c.src[0])).sink()
ret = out.after(c.replace(src=(fxn, *input_buffers, out), dtype=dtypes.void, tag=None))

View file

@ -2,6 +2,7 @@ import functools
from typing import Generic, TypeVar, Callable, cast, overload
from tinygrad.helpers import Context, dedup, getenv
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
from tinygrad.gradient import compute_gradient
from tinygrad.tensor import Tensor
def add_to_ctx(ctx, x:UOp):
@ -19,9 +20,10 @@ pm_ctx = PatternMatcher([
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False):
self.fxn = fxn
self.precompile = precompile
self.precompile_backward = precompile_backward
def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
@ -39,12 +41,17 @@ class _function(Generic[ReturnType]):
# run it and do surgery later
with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
ret = self.fxn(*args, **kwargs)
assert isinstance(ret, Tensor), "only supports one tensor return for now"
if isinstance(ret, Tensor):
uret = ret.uop
elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):
uret = UOp.maketuple(*[x.uop for x in ret])
else:
raise RuntimeError(f"function return type {type(ret)} not supported")
# replace the known inputs with params (using deduplicated slots)
subs = {}
for i,x in enumerate(call_uops): subs[x] = x.param_like(i)
uret = ret.uop.substitute(subs)
uret = uret.substitute(subs)
# add contiguous to call_uops
#call_uops = [x.contiguous() for x in call_uops]
@ -60,14 +67,77 @@ class _function(Generic[ReturnType]):
#call = assigned.call(*call_uops, buffer, name=name)
#ret = buffer.after(call)
ret = uret.call(*call_uops, name=name, precompile=self.precompile)
return cast(ReturnType, Tensor(ret, device=ret.device))
# precompute the backward: determine which inputs need gradients and build the backward CALL body
grad_fxn = None
inputs = [t for t in list(args)+[kwargs[k] for k in sorted(kwargs)] if isinstance(t, (Tensor, UOp))]
grad_params = {x.arg:x for x in uret.toposort(enter_calls=False) if x.op == Ops.PARAM}
# find which param slots correspond to requires_grad inputs
need_grad = {i for i, t in enumerate(inputs) if isinstance(t, Tensor) and t.requires_grad}
target_params = {grad_params[i] for i in need_grad if i in grad_params}
if target_params:
grad_fxn = self._make_grad_fxn(uret, len(call_uops), target_params, need_grad, name, self.precompile_backward)
fret = uret.call(*call_uops, name=name, precompile=self.precompile, precompile_backward=self.precompile_backward, grad_fxn=grad_fxn)
if isinstance(ret, tuple):
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
else:
return cast(ReturnType, Tensor(fret, device=fret.device))
@staticmethod
def _make_grad_fxn(uret:UOp, num_args:int, target_params:set[UOp], need_grad:set[int], name:str, precompile_backward:bool):
def grad_fxn(ctx, k):
fxn, args = k.src[0], k.src[1:]
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
# compute gradients only for needed params
if isinstance(ctx, dict):
all_grads: dict[UOp, UOp] = {}
for idx, grad_out in ctx.items():
elem_grads = compute_gradient(fxn.src[idx], grad_out.param_like(len(args) + idx), target_params)
for p, g in elem_grads.items():
if p in all_grads: all_grads[p] = all_grads[p] + g
else: all_grads[p] = g
grads = all_grads
grad_ctx_inputs = tuple(ctx.get(i, fxn.src[i].const_like(0)) for i in range(len(fxn.src)))
else:
grads = compute_gradient(fxn, ctx.param_like(len(args)), target_params)
grad_ctx_inputs = (ctx,)
# collect gradients for needed params only
grad_indices: list[int] = []
grad_uops: list[UOp] = []
for i in range(len(args)):
if i in need_grad and (p:=params.get(i)) is not None and p in grads:
grad_indices.append(i)
grad_uops.append(grads[p])
if len(grad_uops) == 0: return (None,) * len(args)
# replace forward output references with PARAMs to avoid recomputation
if fxn.op is Ops.TUPLE:
fwd_subs = {elem: elem.param_like(len(args) + len(grad_ctx_inputs) + i) for i, elem in enumerate(fxn.src)}
fwd_inputs = tuple(k.gettuple(i) for i in range(len(fxn.src)))
else:
fwd_subs = {fxn: fxn.param_like(len(args) + 1)}
fwd_inputs = (k,)
grad_uops = [g.substitute(fwd_subs) for g in grad_uops]
# build a single backward CALL returning a TUPLE of all gradients
bwd_body = UOp.maketuple(*grad_uops)
bwd_call = bwd_body.call(*args, *grad_ctx_inputs, *fwd_inputs, name=name+"_backward", precompile=precompile_backward)
# extract each gradient via GETTUPLE
ret: list[UOp|None] = []
gi = 0
for i in range(len(args)):
if gi < len(grad_indices) and grad_indices[gi] == i:
ret.append(bwd_call.gettuple(gi))
gi += 1
else:
ret.append(None)
return tuple(ret)
return grad_fxn
# overload signatures support both @function and @function(precompile=True) syntax
@overload
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False) -> _function[ReturnType]: ...
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False) -> _function[ReturnType]: ...
@overload
def function(fxn:None=None, *, precompile:bool=False) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False):
if fxn is None: return lambda f: _function(f, precompile=precompile)
return _function(fxn, precompile=precompile)
def function(fxn:None=None, *, precompile:bool=False, precompile_backward:bool=False) -> \
Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False):
if fxn is None: return lambda f: _function(f, precompile=precompile, precompile_backward=precompile_backward)
return _function(fxn, precompile=precompile, precompile_backward=precompile_backward)

View file

@ -13,18 +13,53 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
def call_gradient(ctx:UOp|dict[int, UOp], k:UOp) -> tuple[UOp|None, ...]:
if k.arg.grad_fxn is not None:
if isinstance(ctx, dict): return (None,) + k.arg.grad_fxn(ctx, k)
return (None,) + k.arg.grad_fxn(ctx, k)
# auto-differentiate the function
fxn, args = k.src[0], k.src[1:]
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
ret: list[UOp|None] = [None]
# for tuple-returning CALLs, ctx is a dict mapping output index to gradient; differentiate each output separately and sum
if isinstance(ctx, dict):
all_grads: dict[UOp, UOp] = {}
for idx, grad_out in ctx.items():
elem_grads = compute_gradient(fxn.src[idx], grad_out.param_like(len(args) + idx), set(params.values()))
for p, g in elem_grads.items():
if p in all_grads: all_grads[p] = all_grads[p] + g
else: all_grads[p] = g
grads = all_grads
grad_ctx_inputs = tuple(ctx.get(i, fxn.src[i].const_like(0)) for i in range(len(fxn.src)))
else:
grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
grad_ctx_inputs = (ctx,)
# collect which args have gradients
grad_indices: list[int] = []
grad_uops: list[UOp] = []
for i in range(len(args)):
if (p:=params.get(i, None)) is not None and p in grads:
# TODO: compact the args and remove unused ones
assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
ret.append(grads[p].call(*args, ctx, name=(k.arg.name or "")+f"_backward_{i}"))
grad_indices.append(i)
grad_uops.append(grads[p])
if len(grad_uops) == 0: return (None,) * (len(args) + 1)
# replace forward output references with PARAMs, passing the forward CALL output(s) as inputs to avoid recomputation
if fxn.op is Ops.TUPLE:
fwd_subs = {elem: elem.param_like(len(args) + len(grad_ctx_inputs) + i) for i, elem in enumerate(fxn.src)}
fwd_inputs = tuple(k.gettuple(i) for i in range(len(fxn.src)))
else:
fwd_subs = {fxn: fxn.param_like(len(args) + 1)}
fwd_inputs = (k,)
grad_uops = [g.substitute(fwd_subs) for g in grad_uops]
# build a single backward CALL returning a TUPLE of all gradients
bwd_body = UOp.maketuple(*grad_uops)
bwd_call = bwd_body.call(*args, *grad_ctx_inputs, *fwd_inputs, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
# extract each gradient via GETTUPLE
ret: list[UOp|None] = [None]
gi = 0
for i in range(len(args)):
if gi < len(grad_indices) and grad_indices[gi] == i:
ret.append(bwd_call.gettuple(gi))
gi += 1
else:
ret.append(None)
return tuple(ret)
@ -72,11 +107,26 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
grads = {root: root_grad}
grads: dict[UOp, UOp] = {root: root_grad}
# for GETTUPLE nodes on tuple-returning CALLs, collect per-output gradients
tuple_call_grads: dict[UOp, dict[int, UOp]] = {}
for t0 in reversed(_deepwalk(root, targets)):
if t0 not in grads: continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
if t0 not in grads and t0 not in tuple_call_grads: continue
# for tuple-returning CALLs, use accumulated per-output gradients
if t0.op is Ops.CALL and t0 in tuple_call_grads:
lgrads = cast(tuple[UOp|None, ...], call_gradient(tuple_call_grads[t0], t0))
elif t0 not in grads:
continue
# for GETTUPLE on a CALL, accumulate gradient per output index instead of propagating to CALL directly
elif t0.op is Ops.GETTUPLE and t0.src[0].op is Ops.CALL:
call = t0.src[0]
if call not in tuple_call_grads: tuple_call_grads[call] = {}
if t0.arg in tuple_call_grads[call]: tuple_call_grads[call][t0.arg] = tuple_call_grads[call][t0.arg] + grads[t0]
else: tuple_call_grads[call][t0.arg] = grads[t0]
continue
else:
lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
for k,v in zip(t0.src, lgrads):
if v is None: continue

View file

@ -137,6 +137,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve TUPLE+GETTUPLE
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
# resolve allreduce (must be bottom up)
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), create_allreduce_function),

View file

@ -39,6 +39,9 @@ class Ops(FastEnum):
# vector creation / item selection
GEP = auto(); VECTORIZE = auto()
# tuple/gettuple for function with multiple returns
TUPLE = auto(); GETTUPLE = auto()
# ** 3 -- load/store **
# INDEX is a BinaryOp similar to ADD, but it operates on pointers

View file

@ -209,9 +209,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE:
return None
case Ops.GETTUPLE:
# GETTUPLE extracts from a TUPLE
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
assert in_tuple.op is Ops.TUPLE
return in_tuple.src[self.arg]._shape
case Ops.CAST:
# when PTX casts from ptr to non ptr, remove the shape
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
@ -404,6 +410,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
return UOp(Ops.TUPLE, dtypes.void, srcs)
def gettuple(self, idx:int) -> UOp:
in_tuple = self.src[0] if self.op is Ops.CALL else self
assert in_tuple.op is Ops.TUPLE, f"gettuple requires CALL or TUPLE source, got {self.op}"
return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx)
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
@ -903,9 +915,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
return p
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False) -> UOp:
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile))
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
@ -929,9 +942,12 @@ class CallInfo:
metadata: tuple[Metadata, ...] = ()
name: str|None = None
precompile: bool = False
precompile_backward: bool = False
# grad_fxn can't be pickled, but metadata can
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata}, {repr(self.name)}, {self.precompile})"
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile, self.precompile_backward))
def __repr__(self):
gf = id(self.grad_fxn) if self.grad_fxn else None
return f"CallInfo({gf}, {self.metadata}, {repr(self.name)}, {self.precompile}, {self.precompile_backward})"
def should_resolve_call(c:UOp) -> bool:
# don't resolve real kernel calls, sink or program

View file

@ -139,6 +139,10 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.PARAM), lambda: True),
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
# TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE
(UPat(Ops.TUPLE, dtypes.void), lambda: True),
(UPat(Ops.GETTUPLE, src=(UPat((Ops.CALL, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
# ** for custom kernels **
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)