mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2a0e7491f7 | ||
|
|
e35f7a0777 |
3 changed files with 15 additions and 2 deletions
|
|
@ -60,6 +60,16 @@ class TestCall(unittest.TestCase):
|
||||||
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
|
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
|
||||||
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
||||||
|
|
||||||
|
def test_call_gemm_assign(self):
|
||||||
|
M, K, N = 4, 8, 4
|
||||||
|
a = Tensor.randn(M, K)
|
||||||
|
b = Tensor.randn(K, N)
|
||||||
|
c = Tensor.empty(M, N)
|
||||||
|
Tensor.realize(a, b)
|
||||||
|
sink = c.as_param(0).assign(a.as_param(1) @ b.as_param(2)).sink()
|
||||||
|
new_c = c.after(Tensor.call(c, a, b, fxn=sink))
|
||||||
|
np.testing.assert_allclose(new_c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
|
||||||
|
|
||||||
@unittest.skip("needs GEMM on mixins")
|
@unittest.skip("needs GEMM on mixins")
|
||||||
def test_call_gemm_uop(self):
|
def test_call_gemm_uop(self):
|
||||||
M, K, N = 4, 8, 4
|
M, K, N = 4, 8, 4
|
||||||
|
|
|
||||||
|
|
@ -80,8 +80,7 @@ mop_cleanup = PatternMatcher([
|
||||||
|
|
||||||
def resolve_call(c:UOp) -> UOp|None:
|
def resolve_call(c:UOp) -> UOp|None:
|
||||||
# don't resolve real kernel calls, sink or program
|
# don't resolve real kernel calls, sink or program
|
||||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
if c.src[0].op in {Ops.SINK, Ops.PROGRAM}: return None
|
||||||
if c.src[0].op is Ops.PROGRAM: return None
|
|
||||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||||
args = c.src[1:]
|
args = c.src[1:]
|
||||||
# TODO: this check belongs in spec, not here
|
# TODO: this check belongs in spec, not here
|
||||||
|
|
|
||||||
|
|
@ -242,6 +242,10 @@ class Tensor(OpMixin):
|
||||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||||
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
|
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
|
||||||
|
|
||||||
|
# TODO: Tensor should just proxy everything to UOp
|
||||||
|
def after(self, *lst:Tensor): return Tensor(self.uop.after(*[t.uop for t in lst]), device=self.device)
|
||||||
|
def sink(self): return Tensor(self.uop.sink(), device=self.device)
|
||||||
|
|
||||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||||
"""
|
"""
|
||||||
Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.
|
Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue