mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
6 commits
master
...
tuplegtupl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e4d8e03954 | ||
|
|
09574a096a | ||
|
|
b89c233917 | ||
|
|
00d847afdf | ||
|
|
8d75eed0a4 | ||
|
|
e59cbf78bd |
8 changed files with 281 additions and 26 deletions
|
|
@ -3,6 +3,7 @@ import unittest
|
|||
from tinygrad.function import function
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
|
||||
class TestFunction(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
|
|
@ -335,5 +336,100 @@ class TestFunctionMulti(unittest.TestCase):
|
|||
f(x).sum().backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), expected)
|
||||
|
||||
class TestFunctionTuple(unittest.TestCase):
|
||||
def test_tuple(self, precompile=False):
|
||||
x = Tensor.ones(3).contiguous()
|
||||
@function(precompile=precompile)
|
||||
def f(t:Tensor): return (t+1, t+2)
|
||||
t1, t2 = f(x)
|
||||
t1.realize(t2)
|
||||
print(t1.tolist(), t2.tolist())
|
||||
assert t1.tolist() == [2,2,2]
|
||||
assert t2.tolist() == [3,3,3]
|
||||
def test_tuple_precompile(self): self.test_tuple(True)
|
||||
|
||||
class TestFunctionBackward(unittest.TestCase):
|
||||
def test_backward_has_grad(self):
|
||||
N = 16
|
||||
x = Tensor.empty(N, N)
|
||||
w1 = Tensor.empty(N, N, requires_grad=True)
|
||||
@function
|
||||
def f(t:Tensor, w1:Tensor): return (t@w1)
|
||||
f(x, w1).sum().backward()
|
||||
assert w1.grad is not None
|
||||
|
||||
def test_double_matmul_backward(self):
|
||||
N = 16
|
||||
x = Tensor.empty(N, N)
|
||||
w1 = Tensor.empty(N, N, requires_grad=True)
|
||||
w2 = Tensor.empty(N, N, requires_grad=True)
|
||||
ref = Tensor.empty(N)
|
||||
|
||||
@function(precompile=True)
|
||||
def f(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
|
||||
loss = (f(x, w1, w2)-ref).square().mean().backward()
|
||||
loss.realize(w1.grad, w2.grad)
|
||||
|
||||
def test_backward_single_call(self):
|
||||
N = 4
|
||||
x = Tensor.arange(N*N).reshape(N, N).float()
|
||||
w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_()
|
||||
w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_()
|
||||
xn, w1n, w2n = x.numpy(), w1.numpy(), w2.numpy()
|
||||
@function
|
||||
def f(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
|
||||
f(x, w1, w2).sum().backward()
|
||||
assert w1.grad is not None and w2.grad is not None
|
||||
np.testing.assert_allclose(w1.grad.numpy(), xn.T @ np.ones((N,N)) @ w2n.T, atol=1e-3)
|
||||
np.testing.assert_allclose(w2.grad.numpy(), (xn @ w1n).T @ np.ones((N,N)), atol=1e-3)
|
||||
|
||||
def test_backward_precompile_backward(self):
|
||||
N = 4
|
||||
x = Tensor.arange(N*N).reshape(N, N).float().contiguous()
|
||||
w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
|
||||
w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
|
||||
Tensor.realize(x, w1, w2)
|
||||
xn, w1n, w2n = x.numpy(), w1.numpy(), w2.numpy()
|
||||
@function(precompile=True, precompile_backward=True)
|
||||
def f(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
|
||||
loss = f(x, w1, w2).sum().backward()
|
||||
assert w1.grad is not None and w2.grad is not None
|
||||
GlobalCounters.reset()
|
||||
Tensor.realize(loss, w1.grad, w2.grad)
|
||||
np.testing.assert_allclose(w1.grad.numpy(), xn.T @ np.ones((N,N)) @ w2n.T, atol=1e-3)
|
||||
np.testing.assert_allclose(w2.grad.numpy(), (xn @ w1n).T @ np.ones((N,N)), atol=1e-3)
|
||||
|
||||
def test_backward_precompile_backward_tuple(self):
|
||||
N = 4
|
||||
x = Tensor.arange(N*N).reshape(N, N).float().contiguous()
|
||||
w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
|
||||
w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
|
||||
Tensor.realize(x, w1, w2)
|
||||
xn, w1n, w2n = x.numpy(), w1.numpy(), w2.numpy()
|
||||
# non-tuple reference
|
||||
ref_w1 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
|
||||
ref_w2 = Tensor.arange(N*N).reshape(N, N).float().requires_grad_().contiguous()
|
||||
Tensor.realize(ref_w1, ref_w2)
|
||||
@function(precompile=True, precompile_backward=True)
|
||||
def g(t:Tensor, w1:Tensor, w2:Tensor): return (t@w1)@w2
|
||||
g(x, ref_w1, ref_w2).sum().backward()
|
||||
GlobalCounters.reset()
|
||||
Tensor.realize(ref_w1.grad, ref_w2.grad)
|
||||
ref_ops = GlobalCounters.global_ops
|
||||
# tuple version with intermediate — should not redo forward compute
|
||||
@function(precompile=True, precompile_backward=True)
|
||||
def f(t:Tensor, w1:Tensor, w2:Tensor):
|
||||
h = t@w1
|
||||
return (h, h@w2)
|
||||
h, out = f(x, w1, w2)
|
||||
loss = out.sum().backward()
|
||||
assert w1.grad is not None and w2.grad is not None
|
||||
GlobalCounters.reset()
|
||||
Tensor.realize(loss, w1.grad, w2.grad)
|
||||
np.testing.assert_allclose(w1.grad.numpy(), xn.T @ np.ones((N,N)) @ w2n.T, atol=1e-3)
|
||||
np.testing.assert_allclose(w2.grad.numpy(), (xn @ w1n).T @ np.ones((N,N)), atol=1e-3)
|
||||
# tuple version should have fewer ops than non-tuple (saves recomputing h=t@w1 in backward)
|
||||
self.assertLessEqual(GlobalCounters.global_ops, ref_ops)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -94,8 +94,21 @@ def contiguous_mops_to_view(c:UOp):
|
|||
def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
if not c.arg.precompile: return None
|
||||
if c.src[0].op is Ops.SINK: return None
|
||||
out = _buffer_like(c)
|
||||
input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])
|
||||
# multi-output (TUPLE) precompiled calls: allocate a buffer per output and return a TUPLE of buffers
|
||||
if c.src[0].op is Ops.TUPLE:
|
||||
tuple_body = c.src[0]
|
||||
out_bufs = []
|
||||
sink_srcs = []
|
||||
for i, elem in enumerate(tuple_body.src):
|
||||
buf = UOp.new_buffer(c.device, prod(elem.max_shape), elem.dtype).reshape(elem.max_shape).shrink_to(elem.shape)
|
||||
out_bufs.append(buf)
|
||||
target = buf.param_like(len(c.src) - 1 + i).shrink_to(elem.shape)
|
||||
sink_srcs.append(target.after(target.store(elem)))
|
||||
fxn = UOp.sink(*sink_srcs)
|
||||
new_call = c.replace(src=(fxn, *input_buffers, *out_bufs), dtype=dtypes.void, tag=None)
|
||||
return UOp.maketuple(*[buf.after(new_call) for buf in out_bufs])
|
||||
out = _buffer_like(c)
|
||||
target = out.param_like(len(c.src)-1).shrink_to(c.shape)
|
||||
fxn = target.after(target.store(c.src[0])).sink()
|
||||
ret = out.after(c.replace(src=(fxn, *input_buffers, out), dtype=dtypes.void, tag=None))
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import functools
|
|||
from typing import Generic, TypeVar, Callable, cast, overload
|
||||
from tinygrad.helpers import Context, dedup, getenv
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def add_to_ctx(ctx, x:UOp):
|
||||
|
|
@ -19,9 +20,10 @@ pm_ctx = PatternMatcher([
|
|||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
class _function(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False):
|
||||
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False):
|
||||
self.fxn = fxn
|
||||
self.precompile = precompile
|
||||
self.precompile_backward = precompile_backward
|
||||
|
||||
def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
|
||||
|
||||
|
|
@ -39,12 +41,17 @@ class _function(Generic[ReturnType]):
|
|||
# run it and do surgery later
|
||||
with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
|
||||
ret = self.fxn(*args, **kwargs)
|
||||
assert isinstance(ret, Tensor), "only supports one tensor return for now"
|
||||
if isinstance(ret, Tensor):
|
||||
uret = ret.uop
|
||||
elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):
|
||||
uret = UOp.maketuple(*[x.uop for x in ret])
|
||||
else:
|
||||
raise RuntimeError(f"function return type {type(ret)} not supported")
|
||||
|
||||
# replace the known inputs with params (using deduplicated slots)
|
||||
subs = {}
|
||||
for i,x in enumerate(call_uops): subs[x] = x.param_like(i)
|
||||
uret = ret.uop.substitute(subs)
|
||||
uret = uret.substitute(subs)
|
||||
|
||||
# add contiguous to call_uops
|
||||
#call_uops = [x.contiguous() for x in call_uops]
|
||||
|
|
@ -60,14 +67,77 @@ class _function(Generic[ReturnType]):
|
|||
#call = assigned.call(*call_uops, buffer, name=name)
|
||||
#ret = buffer.after(call)
|
||||
|
||||
ret = uret.call(*call_uops, name=name, precompile=self.precompile)
|
||||
return cast(ReturnType, Tensor(ret, device=ret.device))
|
||||
# precompute the backward: determine which inputs need gradients and build the backward CALL body
|
||||
grad_fxn = None
|
||||
inputs = [t for t in list(args)+[kwargs[k] for k in sorted(kwargs)] if isinstance(t, (Tensor, UOp))]
|
||||
grad_params = {x.arg:x for x in uret.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
# find which param slots correspond to requires_grad inputs
|
||||
need_grad = {i for i, t in enumerate(inputs) if isinstance(t, Tensor) and t.requires_grad}
|
||||
target_params = {grad_params[i] for i in need_grad if i in grad_params}
|
||||
if target_params:
|
||||
grad_fxn = self._make_grad_fxn(uret, len(call_uops), target_params, need_grad, name, self.precompile_backward)
|
||||
|
||||
fret = uret.call(*call_uops, name=name, precompile=self.precompile, precompile_backward=self.precompile_backward, grad_fxn=grad_fxn)
|
||||
if isinstance(ret, tuple):
|
||||
return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
|
||||
else:
|
||||
return cast(ReturnType, Tensor(fret, device=fret.device))
|
||||
|
||||
@staticmethod
|
||||
def _make_grad_fxn(uret:UOp, num_args:int, target_params:set[UOp], need_grad:set[int], name:str, precompile_backward:bool):
|
||||
def grad_fxn(ctx, k):
|
||||
fxn, args = k.src[0], k.src[1:]
|
||||
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
# compute gradients only for needed params
|
||||
if isinstance(ctx, dict):
|
||||
all_grads: dict[UOp, UOp] = {}
|
||||
for idx, grad_out in ctx.items():
|
||||
elem_grads = compute_gradient(fxn.src[idx], grad_out.param_like(len(args) + idx), target_params)
|
||||
for p, g in elem_grads.items():
|
||||
if p in all_grads: all_grads[p] = all_grads[p] + g
|
||||
else: all_grads[p] = g
|
||||
grads = all_grads
|
||||
grad_ctx_inputs = tuple(ctx.get(i, fxn.src[i].const_like(0)) for i in range(len(fxn.src)))
|
||||
else:
|
||||
grads = compute_gradient(fxn, ctx.param_like(len(args)), target_params)
|
||||
grad_ctx_inputs = (ctx,)
|
||||
# collect gradients for needed params only
|
||||
grad_indices: list[int] = []
|
||||
grad_uops: list[UOp] = []
|
||||
for i in range(len(args)):
|
||||
if i in need_grad and (p:=params.get(i)) is not None and p in grads:
|
||||
grad_indices.append(i)
|
||||
grad_uops.append(grads[p])
|
||||
if len(grad_uops) == 0: return (None,) * len(args)
|
||||
# replace forward output references with PARAMs to avoid recomputation
|
||||
if fxn.op is Ops.TUPLE:
|
||||
fwd_subs = {elem: elem.param_like(len(args) + len(grad_ctx_inputs) + i) for i, elem in enumerate(fxn.src)}
|
||||
fwd_inputs = tuple(k.gettuple(i) for i in range(len(fxn.src)))
|
||||
else:
|
||||
fwd_subs = {fxn: fxn.param_like(len(args) + 1)}
|
||||
fwd_inputs = (k,)
|
||||
grad_uops = [g.substitute(fwd_subs) for g in grad_uops]
|
||||
# build a single backward CALL returning a TUPLE of all gradients
|
||||
bwd_body = UOp.maketuple(*grad_uops)
|
||||
bwd_call = bwd_body.call(*args, *grad_ctx_inputs, *fwd_inputs, name=name+"_backward", precompile=precompile_backward)
|
||||
# extract each gradient via GETTUPLE
|
||||
ret: list[UOp|None] = []
|
||||
gi = 0
|
||||
for i in range(len(args)):
|
||||
if gi < len(grad_indices) and grad_indices[gi] == i:
|
||||
ret.append(bwd_call.gettuple(gi))
|
||||
gi += 1
|
||||
else:
|
||||
ret.append(None)
|
||||
return tuple(ret)
|
||||
return grad_fxn
|
||||
|
||||
# overload signatures support both @function and @function(precompile=True) syntax
|
||||
@overload
|
||||
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False) -> _function[ReturnType]: ...
|
||||
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, precompile_backward:bool=False) -> _function[ReturnType]: ...
|
||||
@overload
|
||||
def function(fxn:None=None, *, precompile:bool=False) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
|
||||
def function(fxn=None, *, precompile:bool=False):
|
||||
if fxn is None: return lambda f: _function(f, precompile=precompile)
|
||||
return _function(fxn, precompile=precompile)
|
||||
def function(fxn:None=None, *, precompile:bool=False, precompile_backward:bool=False) -> \
|
||||
Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
|
||||
def function(fxn=None, *, precompile:bool=False, precompile_backward:bool=False):
|
||||
if fxn is None: return lambda f: _function(f, precompile=precompile, precompile_backward=precompile_backward)
|
||||
return _function(fxn, precompile=precompile, precompile_backward=precompile_backward)
|
||||
|
|
|
|||
|
|
@ -13,18 +13,53 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
|||
return ((mask/broadcast_to_input(count)) * broadcast_to_input(ctx),)
|
||||
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
||||
|
||||
def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
|
||||
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
|
||||
def call_gradient(ctx:UOp|dict[int, UOp], k:UOp) -> tuple[UOp|None, ...]:
|
||||
if k.arg.grad_fxn is not None:
|
||||
if isinstance(ctx, dict): return (None,) + k.arg.grad_fxn(ctx, k)
|
||||
return (None,) + k.arg.grad_fxn(ctx, k)
|
||||
# auto-differentiate the function
|
||||
fxn, args = k.src[0], k.src[1:]
|
||||
params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
|
||||
grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
|
||||
ret: list[UOp|None] = [None]
|
||||
# for tuple-returning CALLs, ctx is a dict mapping output index to gradient; differentiate each output separately and sum
|
||||
if isinstance(ctx, dict):
|
||||
all_grads: dict[UOp, UOp] = {}
|
||||
for idx, grad_out in ctx.items():
|
||||
elem_grads = compute_gradient(fxn.src[idx], grad_out.param_like(len(args) + idx), set(params.values()))
|
||||
for p, g in elem_grads.items():
|
||||
if p in all_grads: all_grads[p] = all_grads[p] + g
|
||||
else: all_grads[p] = g
|
||||
grads = all_grads
|
||||
grad_ctx_inputs = tuple(ctx.get(i, fxn.src[i].const_like(0)) for i in range(len(fxn.src)))
|
||||
else:
|
||||
grads = compute_gradient(fxn, ctx.param_like(len(args)), set(params.values()))
|
||||
grad_ctx_inputs = (ctx,)
|
||||
# collect which args have gradients
|
||||
grad_indices: list[int] = []
|
||||
grad_uops: list[UOp] = []
|
||||
for i in range(len(args)):
|
||||
if (p:=params.get(i, None)) is not None and p in grads:
|
||||
# TODO: compact the args and remove unused ones
|
||||
assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
|
||||
ret.append(grads[p].call(*args, ctx, name=(k.arg.name or "")+f"_backward_{i}"))
|
||||
grad_indices.append(i)
|
||||
grad_uops.append(grads[p])
|
||||
if len(grad_uops) == 0: return (None,) * (len(args) + 1)
|
||||
# replace forward output references with PARAMs, passing the forward CALL output(s) as inputs to avoid recomputation
|
||||
if fxn.op is Ops.TUPLE:
|
||||
fwd_subs = {elem: elem.param_like(len(args) + len(grad_ctx_inputs) + i) for i, elem in enumerate(fxn.src)}
|
||||
fwd_inputs = tuple(k.gettuple(i) for i in range(len(fxn.src)))
|
||||
else:
|
||||
fwd_subs = {fxn: fxn.param_like(len(args) + 1)}
|
||||
fwd_inputs = (k,)
|
||||
grad_uops = [g.substitute(fwd_subs) for g in grad_uops]
|
||||
# build a single backward CALL returning a TUPLE of all gradients
|
||||
bwd_body = UOp.maketuple(*grad_uops)
|
||||
bwd_call = bwd_body.call(*args, *grad_ctx_inputs, *fwd_inputs, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
|
||||
# extract each gradient via GETTUPLE
|
||||
ret: list[UOp|None] = [None]
|
||||
gi = 0
|
||||
for i in range(len(args)):
|
||||
if gi < len(grad_indices) and grad_indices[gi] == i:
|
||||
ret.append(bwd_call.gettuple(gi))
|
||||
gi += 1
|
||||
else:
|
||||
ret.append(None)
|
||||
return tuple(ret)
|
||||
|
|
@ -72,11 +107,26 @@ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
|
|||
return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))
|
||||
|
||||
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
|
||||
grads = {root: root_grad}
|
||||
grads: dict[UOp, UOp] = {root: root_grad}
|
||||
# for GETTUPLE nodes on tuple-returning CALLs, collect per-output gradients
|
||||
tuple_call_grads: dict[UOp, dict[int, UOp]] = {}
|
||||
for t0 in reversed(_deepwalk(root, targets)):
|
||||
if t0 not in grads: continue
|
||||
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
||||
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
|
||||
if t0 not in grads and t0 not in tuple_call_grads: continue
|
||||
# for tuple-returning CALLs, use accumulated per-output gradients
|
||||
if t0.op is Ops.CALL and t0 in tuple_call_grads:
|
||||
lgrads = cast(tuple[UOp|None, ...], call_gradient(tuple_call_grads[t0], t0))
|
||||
elif t0 not in grads:
|
||||
continue
|
||||
# for GETTUPLE on a CALL, accumulate gradient per output index instead of propagating to CALL directly
|
||||
elif t0.op is Ops.GETTUPLE and t0.src[0].op is Ops.CALL:
|
||||
call = t0.src[0]
|
||||
if call not in tuple_call_grads: tuple_call_grads[call] = {}
|
||||
if t0.arg in tuple_call_grads[call]: tuple_call_grads[call][t0.arg] = tuple_call_grads[call][t0.arg] + grads[t0]
|
||||
else: tuple_call_grads[call][t0.arg] = grads[t0]
|
||||
continue
|
||||
else:
|
||||
lgrads = cast(tuple[UOp|None, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
|
||||
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
|
||||
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
|
||||
for k,v in zip(t0.src, lgrads):
|
||||
if v is None: continue
|
||||
|
|
|
|||
|
|
@ -137,6 +137,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
# resolve calls
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# resolve TUPLE+GETTUPLE
|
||||
(UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
|
||||
|
||||
# resolve allreduce (must be bottom up)
|
||||
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), create_allreduce_function),
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,9 @@ class Ops(FastEnum):
|
|||
# vector creation / item selection
|
||||
GEP = auto(); VECTORIZE = auto()
|
||||
|
||||
# tuple/gettuple for function with multiple returns
|
||||
TUPLE = auto(); GETTUPLE = auto()
|
||||
|
||||
# ** 3 -- load/store **
|
||||
|
||||
# INDEX is a BinaryOp similar to ADD, but it operates on pointers
|
||||
|
|
|
|||
|
|
@ -209,9 +209,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.STORE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE:
|
||||
return None
|
||||
|
||||
case Ops.GETTUPLE:
|
||||
# GETTUPLE extracts from a TUPLE
|
||||
in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
|
||||
assert in_tuple.op is Ops.TUPLE
|
||||
return in_tuple.src[self.arg]._shape
|
||||
|
||||
case Ops.CAST:
|
||||
# when PTX casts from ptr to non ptr, remove the shape
|
||||
if isinstance(self.src[0].dtype, PtrDType) and not isinstance(self.src[0].dtype, ImageDType) and not isinstance(self.dtype, PtrDType):
|
||||
|
|
@ -404,6 +410,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
def sink(*srcs:UOp|None, **kwargs): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.TUPLE, dtypes.void, srcs)
|
||||
def gettuple(self, idx:int) -> UOp:
|
||||
in_tuple = self.src[0] if self.op is Ops.CALL else self
|
||||
assert in_tuple.op is Ops.TUPLE, f"gettuple requires CALL or TUPLE source, got {self.op}"
|
||||
return UOp(Ops.GETTUPLE, in_tuple.src[idx].dtype, (self,), idx)
|
||||
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
|
||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||
|
|
@ -903,9 +915,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
|
||||
return p
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False) -> UOp:
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
|
||||
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile))
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
|
|
@ -929,9 +942,12 @@ class CallInfo:
|
|||
metadata: tuple[Metadata, ...] = ()
|
||||
name: str|None = None
|
||||
precompile: bool = False
|
||||
precompile_backward: bool = False
|
||||
# grad_fxn can't be pickled, but metadata can
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata}, {repr(self.name)}, {self.precompile})"
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile, self.precompile_backward))
|
||||
def __repr__(self):
|
||||
gf = id(self.grad_fxn) if self.grad_fxn else None
|
||||
return f"CallInfo({gf}, {self.metadata}, {repr(self.name)}, {self.precompile}, {self.precompile_backward})"
|
||||
|
||||
def should_resolve_call(c:UOp) -> bool:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
|
|
|
|||
|
|
@ -139,6 +139,10 @@ _tensor_spec = PatternMatcher([
|
|||
(UPat(Ops.PARAM), lambda: True),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
|
||||
|
||||
# TUPLE must have void dtype, GETTUPLE can only appear on CALL or TUPLE
|
||||
(UPat(Ops.TUPLE, dtypes.void), lambda: True),
|
||||
(UPat(Ops.GETTUPLE, src=(UPat((Ops.CALL, Ops.TUPLE)),), name="g"), lambda g: isinstance(g.arg, int)),
|
||||
|
||||
# ** for custom kernels **
|
||||
|
||||
# codegen: PROGRAM with progressive sources through the pipeline (SINK, DEVICE, LINEAR?, SOURCE?, BINARY?)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue