Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
9a4f54a1f3 don't use the wmma args as much 2025-07-25 18:46:46 -07:00

View file

@@ -308,13 +308,14 @@ class MetalRenderer(CStyleLanguage):
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
  """Render a kernel, prepending the Metal header and one helper function per distinct WMMA signature.

  Every uop whose op is ``Ops.WMMA`` contributes a ``(uop.dtype, uop.src[0].dtype, uop.arg[0])`` triple;
  duplicates collapse, so each distinct triple yields exactly one ``__<name>`` helper. Each helper copies
  elements 0 and 1 of ``a``, ``b`` and ``c`` into 8x8 simdgroup matrices, calls
  ``simdgroup_multiply_accumulate`` and returns elements 0 and 1 of the accumulator packed as the output type.

  The incoming ``prefix`` argument is not read: it is replaced by the list built here, which is then
  forwarded to the parent class's ``render_kernel`` together with the other arguments unchanged.
  Helpers are appended in set-iteration order.
  """
  # Only the WMMA name is taken from uop.arg; the output/input types come from the uop itself and its first source.
  wmma_args = {(uop.dtype, uop.src[0].dtype, uop.arg[0]) for uop in uops if uop.op is Ops.WMMA}
  prefix = ["#include <metal_stdlib>", "using namespace metal;"]
  for dtype_out, dtype_in, name in wmma_args:
    dstr_out = self.render_dtype(dtype_out)
    dstr_in = self.render_dtype(dtype_in)
    # The simdgroup matrix types are named after the scalar element type, not the vector type passed in.
    scalar_in = self.render_dtype(dtype_in.scalar())
    scalar_out = self.render_dtype(dtype_out.scalar())
    prefix.append(
      f"{dstr_out} __{name}({dstr_in} a, {dstr_in} b, {dstr_out} c){{\n"
      f"simdgroup_{scalar_in}8x8 mat_a, mat_b; simdgroup_{scalar_out}8x8 mat_c;\n"
      "mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];\n"
      "mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];\n"
      "simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n"
      f" return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)
_nms = "xyzwabcdefghijkl"