Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
e7cfac01f8 test_mismatch_reduce 2025-08-06 09:49:37 -07:00
2 changed files with 8 additions and 1 deletions

View file

@ -111,6 +111,13 @@ class TestFuse(unittest.TestCase):
with Context(NOOPT=1):
self._test_fuse(Tensor.scaled_dot_product_attention, q, k, v, atol=1e-5)
@unittest.expectedFailure
def test_mismatch_reduce(self):
a = Tensor.ones(16, 10).contiguous().realize()
b = Tensor.ones(16, 20).contiguous().realize()
c = (a.sum(axis=1) + b.sum(axis=1)).fuse()
self.assertListEqual(c.tolist(), [30]*16)
class TestSoftmaxFusion(unittest.TestCase):
@classmethod
def setUpClass(cls):

View file

@ -39,7 +39,7 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
if DEBUG >= 6: print_uops(uops)
src = renderer.render(uops)
return ProgramSpec(uops[-1].arg.name, src, renderer.device, ast, uops,
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
# **************** Runners ****************