mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
less_wmma_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a4f54a1f3 |
1 changed files with 6 additions and 5 deletions
|
|
@ -308,13 +308,14 @@ class MetalRenderer(CStyleLanguage):
|
|||
]) + base_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
|
||||
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
|
||||
for arg in wmma_args: prefix.append(
|
||||
f"""{(dtype_out:=self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in:=self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c){{
|
||||
simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c;
|
||||
wmma_args = set([(uop.dtype, uop.src[0].dtype, uop.arg[0]) for uop in uops if uop.op is Ops.WMMA])
|
||||
prefix = ["#include <metal_stdlib>","using namespace metal;"]
|
||||
for dtype_out, dtype_in, name in wmma_args: prefix.append(
|
||||
f"""{(dstr_out:=self.render_dtype(dtype_out))} __{name}({(dstr_in:=self.render_dtype(dtype_in))} a, {dstr_in} b, {dstr_out} c){{
|
||||
simdgroup_{self.render_dtype(dtype_in.scalar())}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out.scalar())}8x8 mat_c;
|
||||
mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
|
||||
mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
|
||||
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
|
||||
simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
_nms = "xyzwabcdefghijkl"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue