Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
2a0e7491f7 gemm assign 2026-02-10 17:57:14 +08:00
George Hotz
e35f7a0777 all sink/program calls aren't sub in graph 2026-02-10 17:38:51 +08:00
3 changed files with 15 additions and 2 deletions

View file

@@ -60,6 +60,16 @@ class TestCall(unittest.TestCase):
c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1)) c = Tensor.call(a, b, fxn=a.as_param(0) @ b.as_param(1))
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6) np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-5, atol=1e-6)
def test_call_gemm_assign(self):
  # GEMM through Tensor.call where the called function assigns the matmul
  # result into parameter 0 (the preallocated output) and ends in a sink.
  M, K, N = 4, 8, 4
  lhs = Tensor.randn(M, K)
  rhs = Tensor.randn(K, N)
  out = Tensor.empty(M, N)
  # realize the inputs up front so the numpy reference below sees the same values
  Tensor.realize(lhs, rhs)
  # param indices follow the positional order of the call below: out=0, lhs=1, rhs=2
  dst = out.as_param(0)
  product = lhs.as_param(1) @ rhs.as_param(2)
  fxn = dst.assign(product).sink()
  # read `out` through after() of the call result before comparing
  result = out.after(Tensor.call(out, lhs, rhs, fxn=fxn))
  got = result.numpy()
  want = lhs.numpy() @ rhs.numpy()
  np.testing.assert_allclose(got, want, rtol=1e-5, atol=1e-6)
@unittest.skip("needs GEMM on mixins") @unittest.skip("needs GEMM on mixins")
def test_call_gemm_uop(self): def test_call_gemm_uop(self):
M, K, N = 4, 8, 4 M, K, N = 4, 8, 4

View file

@@ -80,8 +80,7 @@ mop_cleanup = PatternMatcher([
def resolve_call(c:UOp) -> UOp|None: def resolve_call(c:UOp) -> UOp|None:
# don't resolve real kernel calls, sink or program # don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None if c.src[0].op in {Ops.SINK, Ops.PROGRAM}: return None
if c.src[0].op is Ops.PROGRAM: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg) params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:] args = c.src[1:]
# TODO: this check belongs in spec, not here # TODO: this check belongs in spec, not here

View file

@@ -242,6 +242,10 @@ class Tensor(OpMixin):
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor: def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device) return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
# TODO: Tensor should just proxy everything to UOp
# Tensor-level proxy for UOp.after: forwards the uops of `lst` to self.uop.after
# and wraps the resulting UOp in a new Tensor on this tensor's device.
# NOTE(review): presumably this orders `self` after the given tensors -- confirm against UOp.after.
def after(self, *lst:Tensor):
  deps = [t.uop for t in lst]
  return Tensor(self.uop.after(*deps), device=self.device)
# Tensor-level proxy for UOp.sink: wraps self.uop.sink() in a new Tensor on this tensor's device.
def sink(self):
  sunk = self.uop.sink()
  return Tensor(sunk, device=self.device)
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]: def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
""" """
Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied. Call into a custom kernel written in UOps. Returns the Tensors after the Kernel has been applied.