mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
call_inlin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2fe45b0660 |
||
|
|
1d88723aa0 |
||
|
|
b0dd3af093 | ||
|
|
e89221e9aa |
5 changed files with 17 additions and 16 deletions
|
|
@ -13,7 +13,7 @@ class TestCall(unittest.TestCase):
|
||||||
# we define a plus function
|
# we define a plus function
|
||||||
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
||||||
|
|
||||||
c = Tensor.call(a, b, fxn=plus_fxn)
|
c = Tensor.call(a, b, fxn=plus_fxn, inline=True)
|
||||||
np.testing.assert_equal(c.numpy(), (a+b).numpy())
|
np.testing.assert_equal(c.numpy(), (a+b).numpy())
|
||||||
|
|
||||||
def test_call_plus_backward(self):
|
def test_call_plus_backward(self):
|
||||||
|
|
@ -30,7 +30,7 @@ class TestCall(unittest.TestCase):
|
||||||
|
|
||||||
# we define a plus function
|
# we define a plus function
|
||||||
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
||||||
c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn)
|
c = Tensor.call(a, b, fxn=plus_fxn, grad_fxn=grad_fxn, inline=True)
|
||||||
c.mean().backward()
|
c.mean().backward()
|
||||||
|
|
||||||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||||
|
|
@ -46,7 +46,7 @@ class TestCall(unittest.TestCase):
|
||||||
a.grad, b.grad = None, None
|
a.grad, b.grad = None, None
|
||||||
|
|
||||||
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
plus_fxn = UOp.param(0, dtypes.float, (10,10)) + UOp.param(1, dtypes.float, (10,10))
|
||||||
c = Tensor.call(a, b, fxn=plus_fxn)
|
c = Tensor.call(a, b, fxn=plus_fxn, inline=True)
|
||||||
c.mean().backward()
|
c.mean().backward()
|
||||||
|
|
||||||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||||
|
|
@ -57,7 +57,7 @@ class TestCall(unittest.TestCase):
|
||||||
a = Tensor.randn(M, K)
|
a = Tensor.randn(M, K)
|
||||||
b = Tensor.randn(K, N)
|
b = Tensor.randn(K, N)
|
||||||
Tensor.realize(a, b)
|
Tensor.realize(a, b)
|
||||||
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
|
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1), inline=True)
|
||||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
||||||
|
|
||||||
@unittest.skip("needs GEMM on mixins")
|
@unittest.skip("needs GEMM on mixins")
|
||||||
|
|
@ -70,7 +70,7 @@ class TestCall(unittest.TestCase):
|
||||||
# we define a gemm function
|
# we define a gemm function
|
||||||
x = UOp.param(0, dtypes.float, shape=(M, K))
|
x = UOp.param(0, dtypes.float, shape=(M, K))
|
||||||
y = UOp.param(1, dtypes.float, shape=(K, N))
|
y = UOp.param(1, dtypes.float, shape=(K, N))
|
||||||
c = Tensor.call(a, b, fxn=x@y)
|
c = Tensor.call(a, b, fxn=x@y, inline=True)
|
||||||
|
|
||||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
||||||
|
|
||||||
|
|
@ -86,7 +86,7 @@ class TestCall(unittest.TestCase):
|
||||||
|
|
||||||
p0, p1 = UOp.param(0, dtypes.float, (10,10)), UOp.param(1, dtypes.float, (10,10))
|
p0, p1 = UOp.param(0, dtypes.float, (10,10)), UOp.param(1, dtypes.float, (10,10))
|
||||||
complex_fxn = (p0*p1 + p0).exp2() * p1.reciprocal()
|
complex_fxn = (p0*p1 + p0).exp2() * p1.reciprocal()
|
||||||
c = Tensor.call(a, b, fxn=complex_fxn)
|
c = Tensor.call(a, b, fxn=complex_fxn, inline=True)
|
||||||
c.mean().backward()
|
c.mean().backward()
|
||||||
|
|
||||||
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
np.testing.assert_allclose(a.grad.numpy(), gt_a_grad, rtol=1e-5)
|
||||||
|
|
|
||||||
|
|
@ -360,7 +360,8 @@ class Embedding:
|
||||||
|
|
||||||
def __call__(self, idx:Tensor) -> Tensor:
|
def __call__(self, idx:Tensor) -> Tensor:
|
||||||
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
|
if not dtypes.is_int(idx.dtype): raise TypeError(f"Expected integer dtype for index in embedding, got {idx.dtype}")
|
||||||
if USE_ATOMICS: return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd)
|
if USE_ATOMICS:
|
||||||
|
return Tensor.call(self.weight, idx, fxn=_embedding_fwd(self.weight.as_param(0), idx.as_param(1)), grad_fxn=_embedding_bwd, inline=True)
|
||||||
return _embedding_fwd(self.weight, idx)
|
return _embedding_fwd(self.weight, idx)
|
||||||
|
|
||||||
class LSTMCell:
|
class LSTMCell:
|
||||||
|
|
|
||||||
|
|
@ -79,9 +79,8 @@ mop_cleanup = PatternMatcher([
|
||||||
])
|
])
|
||||||
|
|
||||||
def resolve_call(c:UOp) -> UOp|None:
|
def resolve_call(c:UOp) -> UOp|None:
|
||||||
# don't resolve real kernel calls, sink or program
|
# we only resolve here if the call is inlined
|
||||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
if not c.arg.inline: return None
|
||||||
if c.src[0].op is Ops.PROGRAM: return None
|
|
||||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||||
args = c.src[1:]
|
args = c.src[1:]
|
||||||
# TODO: this check belongs in spec, not here
|
# TODO: this check belongs in spec, not here
|
||||||
|
|
|
||||||
|
|
@ -239,8 +239,8 @@ class Tensor(OpMixin):
|
||||||
else:
|
else:
|
||||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||||
return Tensor(param, device=self.device)
|
return Tensor(param, device=self.device)
|
||||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
def call(self, *lst:Tensor, fxn:Tensor|UOp, **kwargs) -> Tensor:
|
||||||
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
|
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], **kwargs), device=self.device)
|
||||||
|
|
||||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -823,10 +823,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||||
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
||||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||||
|
|
||||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), inline=False) -> UOp:
|
||||||
# TODO: reenable this after ENCDEC is fixed
|
# TODO: reenable this after ENCDEC is fixed
|
||||||
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, inline))
|
||||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||||
|
|
@ -848,9 +848,10 @@ class KernelInfo:
|
||||||
class CallInfo:
|
class CallInfo:
|
||||||
grad_fxn: Callable|None = None
|
grad_fxn: Callable|None = None
|
||||||
metadata: tuple[Metadata, ...] = ()
|
metadata: tuple[Metadata, ...] = ()
|
||||||
|
inline: bool = False
|
||||||
# grad_fxn can't be pickled, but metadata can
|
# grad_fxn can't be pickled, but metadata can
|
||||||
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
def __reduce__(self): return (CallInfo, (None, self.metadata, self.inline))
|
||||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata}, {self.inline})"
|
||||||
|
|
||||||
# ******** ops in python ********
|
# ******** ops in python ********
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue