triple gemm

This commit is contained in:
George Hotz 2025-07-30 17:03:02 -07:00
commit 8dff2c1375

View file

@ -33,6 +33,14 @@ class TestTiny(unittest.TestCase):
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
def test_double_gemm(self, N=64, out_dtype=dtypes.float):
a = Tensor.ones(N,N).contiguous().realize()
b = Tensor.eye(N).contiguous().realize()
c = Tensor.eye(N).contiguous().realize()
d = Tensor.eye(N).contiguous().realize()
out = (((a@b).relu()@c).relu()@d).contiguous().realize()
self.assertListEqual(out.flatten().tolist(), [1.0]*(N*N))
# *** randomness ***
def test_random(self):