mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
triple gemm
This commit is contained in:
parent
3cff1a6b13
commit
8dff2c1375
1 changed files with 8 additions and 0 deletions
|
|
@ -33,6 +33,14 @@ class TestTiny(unittest.TestCase):
|
|||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
def test_double_gemm(self, N=64, out_dtype=dtypes.float):
|
||||
a = Tensor.ones(N,N).contiguous().realize()
|
||||
b = Tensor.eye(N).contiguous().realize()
|
||||
c = Tensor.eye(N).contiguous().realize()
|
||||
d = Tensor.eye(N).contiguous().realize()
|
||||
out = (((a@b).relu()@c).relu()@d).contiguous().realize()
|
||||
self.assertListEqual(out.flatten().tolist(), [1.0]*(N*N))
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
def test_random(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue