mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7cfac01f8 |
2 changed files with 8 additions and 1 deletions
|
|
@ -111,6 +111,13 @@ class TestFuse(unittest.TestCase):
|
|||
with Context(NOOPT=1):
|
||||
self._test_fuse(Tensor.scaled_dot_product_attention, q, k, v, atol=1e-5)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_mismatch_reduce(self):
|
||||
a = Tensor.ones(16, 10).contiguous().realize()
|
||||
b = Tensor.ones(16, 20).contiguous().realize()
|
||||
c = (a.sum(axis=1) + b.sum(axis=1)).fuse()
|
||||
self.assertListEqual(c.tolist(), [30]*16)
|
||||
|
||||
class TestSoftmaxFusion(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
|
|||
if DEBUG >= 6: print_uops(uops)
|
||||
src = renderer.render(uops)
|
||||
|
||||
return ProgramSpec(uops[-1].arg.name, src, renderer.device, ast, uops,
|
||||
return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
|
||||
global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
|
||||
|
||||
# **************** Runners ****************
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue