wmma: enable METAL half tensor cores and clean up cstyle (#3095)

* wmma: enable METAL half tensor cores and clean up cstyle

* revert simple_matmul rand changes and break line in tensor

* added metal fp16->fp32 tensor core
This commit is contained in:
Francis Lam 2024-01-12 13:25:28 -08:00 committed by GitHub
commit ddbdb52f77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 37 additions and 42 deletions

View file

@ -87,10 +87,7 @@ class TestLinearizer(unittest.TestCase):
if tc.arch is not None and tc.arch != os.uname().machine: continue
a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in)
np_a, np_b = a.numpy(), b.numpy()
if tc.dtype_out != tc.dtype_in:
r = (a.reshape(tc.dims[0], 1, tc.dims[2]) * b.permute(1,0).reshape(1, tc.dims[1], tc.dims[2])).cast(tc.dtype_out).sum(axis=2)
else:
r = a @ b
r = a.matmul(b, acc_dtype=tc.dtype_out)
realized_ast, _ = helper_realized_ast(r)
k = Linearizer(realized_ast)
k.apply_tensor_cores(1)