Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
cf3f67e0e5 more crap 2026-05-21 13:48:22 -07:00
George Hotz
ea3dd000d3 more movement from count 2026-05-21 13:28:05 -07:00
George Hotz
ded5cdf2ea no dtypes.count in renderer, use shape 2026-05-21 12:52:23 -07:00
8 changed files with 261 additions and 205 deletions

View file

@ -55,7 +55,7 @@ class Estimates:
lds += u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.max_numel()
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))

View file

@ -10,37 +10,44 @@ from tinygrad.codegen.late.devectorizer import no_vectorized_alu
base_rewrite = PatternMatcher([
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:
f"{ctx.render_dtype(x.dtype.base.scalar(), lanes=x.max_shape[-1] if len(x.max_shape) > 1 else 1)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.END)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
# r method accesses
(UPat(Ops.RANGE, name="x"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype, lanes=x.max_numel())} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
(UPat(Ops.STACK, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype, lanes=x.max_numel()))}" + \
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x:
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
f"__builtin_convertvector({ctx[x.src[0]]}, "
f"{ctx.render_dtype(x.dtype, lanes=x.max_numel())})" if x.max_numel() > 1 and not isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]], lanes=x.max_numel())})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
f"__builtin_bit_cast({ctx.render_dtype(x.dtype, lanes=x.max_numel())}, "
f"({ctx.render_dtype(x.src[0].dtype, lanes=x.src[0].max_numel())})({ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x:
f"{ctx.smem_align}{ctx.smem_prefix}"
f"{ctx.render_dtype(x.dtype.base.scalar(), lanes=x.max_shape[-1] if len(x.max_shape) > 1 else 1)} {ctx[x]}[{x.dtype.size}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0]](x.arg[-1])}; /* {(x.src[0]).render()} */"),
# const
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
(UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"),
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity, lanes=x.max_numel())})"),
(UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}', lanes=x.max_numel())})"),
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x:
f"({ctx.render_cast(x.dtype, ctx.nan, lanes=x.max_numel())})" if math.isnan(x.arg) else None),
(UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}ull"),
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts are rendered to larger type and casted
(UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
(UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x:
f"({ctx.render_cast(x.dtype, f'{x.arg}f', lanes=x.max_numel())})"),
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u', lanes=x.max_numel())})"),
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg), lanes=x.max_numel())})"),
# default const render
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
@ -53,7 +60,7 @@ base_rewrite = PatternMatcher([
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
(f"[{x.arg[0]}]" if x.src[0].max_numel() > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
# custom passes through with format
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
])
@ -96,6 +103,9 @@ pm_manual_bf16_cast = PatternMatcher([
])
def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
def uops_to_type_lanes(uops:list[UOp]) -> list[tuple[DType, int]]:
return dedup((u.dtype.scalar(), u.max_numel()) for u in uops if u.dtype is not dtypes.void and \
not isinstance(u.dtype, (ImageDType, PtrDType)))
# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
def wmma_args(uops:list[UOp]):
@ -144,15 +154,15 @@ class CStyleLanguage(Renderer):
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
def render_dtype(self, dt:DType, mutable=True) -> str:
def render_cast(self, dt:DType, val: str, lanes:int=1) -> str: return f"({self.render_dtype(dt, lanes=lanes)})({val})"
def render_dtype(self, dt:DType, mutable=True, lanes:int=1) -> str:
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
if isinstance(dt, PtrDType):
prefix = ""
if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast: prefix = self.smem_prefix
if dt.addrspace == AddrSpace.GLOBAL: prefix = self.buffer_prefix
return prefix + self.render_dtype(dt.base) + "*"
if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
return prefix + self.render_dtype(dt.base.scalar(), lanes=lanes) + "*"
if lanes > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(lanes)
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
def __getitem__(self, key): return self.r[key] # hacky helper
@ -197,14 +207,14 @@ class CStyleLanguage(Renderer):
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
if (u.op is not Ops.CAST or u.max_numel() == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
(u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
(u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
r[u] = l
else:
if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} and u.dtype != dtypes.void:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
l = f"{self.render_dtype(u.dtype, lanes=u.max_numel())} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {Ops.IF, Ops.RANGE}: depth += 1
@ -242,13 +252,13 @@ class ClangRenderer(CStyleLanguage):
if sys.platform == 'win32':
kernel_typedef = "__attribute__((ms_abi)) void"
def render_vector_prefix(self, dt:DType) -> str:
def render_vector_prefix(self, dt:DType, lanes:int) -> str:
# round (down) to power of two (this is actually the default clang behavior)
alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),ext_vector_type({dt.count})));"
alignment = 2**int(math.log2(dt.itemsize*lanes)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
return f"typedef {self.render_dtype(dt)} {self.render_dtype(dt, lanes=lanes)} __attribute__((aligned({alignment}),ext_vector_type({lanes})));"
def _render_defines(self, uops) -> list[str]:
prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
prefix = [self.render_vector_prefix(dt, lanes) for dt, lanes in uops_to_type_lanes(uops) if lanes > 1]
# https://github.com/corsix/amx
for name, (N, M, _), dtype_in, _, _, _, _, _ in wmma_args(uops):
prefix += [
@ -258,7 +268,7 @@ class ClangRenderer(CStyleLanguage):
# 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
# to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
# wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
out, dt1, dt2 = self.render_dtype(dtype_in.vec(N*N)), self.render_dtype(dtype_in.vec(N)), self.render_dtype(dtype_in.vec(M))
out, dt1, dt2 = self.render_dtype(dtype_in, lanes=N*N), self.render_dtype(dtype_in, lanes=N), self.render_dtype(dtype_in, lanes=M)
prefix += [f"""static {out} __{name}({dt1} data1, {dt2} data2, {out} data0){{
AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
@ -298,19 +308,23 @@ class OpenCLRenderer(CStyleLanguage):
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) + pm_manual_bf16_cast + extra_pm
string_rewrite = PatternMatcher([
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f"as_{ctx.render_dtype(x.dtype, lanes=x.max_numel())}(({ctx.render_dtype(x.src[0].dtype, lanes=x.src[0].max_numel())})({ctx[x.src[0]]}))"),
# bfloat16 constants need to be rendered as their bit pattern since bf16 is stored as ushort
(UPat(Ops.CONST, dtypes.bfloat16, name="x"),
lambda ctx,x: f"{(struct.unpack('I', struct.pack('f', float_to_bf16(x.arg)))[0] >> 16)}u"),
# load/store image (OpenCL)
(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), lambda ctx,buf,idx_y,idx_x: f"IMAGE<{ctx[buf]}, {ctx[idx_y]}, {ctx[idx_x]}>"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx_y,idx_x,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})"),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),)),
lambda ctx,buf,idx_y,idx_x: f"read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]}))"),
(UPat(Ops.LOAD,
dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("var"), UPat.var("gate")), name="x"),
lambda ctx,x,buf,idx_y,idx_x,var,gate:
f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]})):{ctx[var]})" if x.max_numel() == 4 else None),
(UPat(Ops.LOAD, dtype=dtypes.float, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),), name="x"),
lambda ctx,x,buf,idx_y,idx_x: f"read_imagef({ctx[buf]}, smp, (int2)({ctx[idx_x]},{ctx[idx_y]}))" if x.max_numel() == 4 else None),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx_y'), UPat.var('idx_x')),
UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,idx_y,idx_x,var: f"write_imagef({ctx[buf]}, (int2)({ctx[idx_x]},{ctx[idx_y]}), {ctx[var]});"),
UPat.var("var", dtypes.float)), allow_any_len=True),
lambda ctx,buf,idx_y,idx_x,var:
f"write_imagef({ctx[buf]}, (int2)({ctx[idx_x]},{ctx[idx_y]}), {ctx[var]});" if var.max_numel() == 4 else None),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
@ -375,14 +389,15 @@ class MetalRenderer(CStyleLanguage):
]) + extra_pm
string_rewrite = PatternMatcher([
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f"as_type<{ctx.render_dtype(x.dtype, lanes=x.max_numel())}>(({ctx.render_dtype(x.src[0].dtype, lanes=x.src[0].max_numel())})({ctx[x.src[0]]}))"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix = ["#include <metal_stdlib>","using namespace metal;"]
deduped_wmma_args = dedup([(name, dtype_in, dtype_out) for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops)])
for name, dtype_in, dtype_out in deduped_wmma_args: prefix.append(
f"""{(dstr_out:=self.render_dtype(dtype_out.vec(2)))} __{name}({(dstr_in:=self.render_dtype(dtype_in.vec(2)))} a, {dstr_in} b, {dstr_out} c){{
f"""{(dstr_out:=self.render_dtype(dtype_out, lanes=2))} __{name}({(dstr_in:=self.render_dtype(dtype_in, lanes=2))} a, {dstr_in} b, {dstr_out} c){{
simdgroup_{self.render_dtype(dtype_in)}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out)}8x8 mat_c;
mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
@ -429,12 +444,14 @@ class CUDARenderer(CStyleLanguage):
(UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),
]) + extra_pm
string_rewrite = PatternMatcher([
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"tg_bitcast<{ctx.render_dtype(x.dtype)}>(({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f"tg_bitcast<{ctx.render_dtype(x.dtype, lanes=x.max_numel())}>"
f"(({ctx.render_dtype(x.src[0].dtype, lanes=x.src[0].max_numel())})({ctx[x.src[0]]}))"),
]) + base_rewrite
def render_vector_prefix(self, dt:DType) -> str:
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
def render_vector_prefix(self, dt:DType, lanes:int) -> str:
vec, scal = self.render_dtype(dt, lanes=lanes), self.render_dtype(dt)
elems, header = ', '.join(_nms[:lanes]), ', '.join([f"{scal} {x}" for x in _nms[:lanes]])
return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
@ -445,13 +462,13 @@ class CUDARenderer(CStyleLanguage):
if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
or (dt.count in (2,4,8,16) and dt.scalar() in dtypes.fp8s)]
prefix += [self.render_vector_prefix(dt, lanes) for dt, lanes in uops_to_type_lanes(uops)
if (lanes in (4,8) and dt in {dtypes.half, dtypes.bfloat16}) or (lanes in (2,4,8,16) and dt in dtypes.fp8s)]
dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2" }
dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
wmma_dtypes = [self.render_dtype(dtype, lanes=size) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
operands = [f"%{i}" for i in range(sum(n_operands))]
@ -518,9 +535,9 @@ class HIPRenderer(CStyleLanguage):
float4 = "make_float4"
type_map = {dtypes.bfloat16: "hip_bfloat16", dtypes.fp8e4m3: "hip_fp8", dtypes.fp8e5m2: "hip_bf8"}
extra_matcher = create_non_native_float_pats((dtypes.bfloat16, *dtypes.fp8s)) + PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
lambda x: UOp(Ops.WMMA, x.dtype, (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
lambda x: UOp(Ops.WMMA, x.dtype.scalar(), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
x.src[2]), (*x.arg,)) if x.max_numel() == 4 and x.src[0].dtype.scalar() in dtypes.fp8_ocp and x.src[0].max_numel() == 8 else None),
# bfloat16 constant casting
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
])
@ -529,10 +546,10 @@ class HIPRenderer(CStyleLanguage):
from tinygrad.renderer.amd.elf import assemble_linear
return assemble_linear(prg, lin, self.target.arch)
def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
return f"typedef {scal} {vec} __attribute__((ext_vector_type({dtype.count})));\nstatic inline __attribute__((device)) "+ \
f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"
def render_vector_prefix(self, dtype:DType, lanes:int) -> str:
vec, scal = self.render_dtype(dtype, lanes=lanes), self.render_dtype(dtype)
return f"typedef {scal} {vec} __attribute__((ext_vector_type({lanes})));\nstatic inline __attribute__((device)) "+ \
f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:lanes]])}) {{ return {{ {', '.join(_nms[:lanes])} }}; }}"
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix, ockl = [], []
@ -556,7 +573,7 @@ class HIPRenderer(CStyleLanguage):
v = (((*(unsigned*)&v)&0x7F800000)!=0x7F800000)?__builtin_amdgcn_fmed3f(v,is_bf8?57344.0f:448.0f,is_bf8?-57344.0f:-448.0f) : v;
return (unsigned char)(is_bf8?__builtin_amdgcn_cvt_pk_bf8_f32(v,v,0,false):__builtin_amdgcn_cvt_pk_fp8_f32(v,v,0,false));\n}""")
prefix += [f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml]
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
prefix += [self.render_vector_prefix(dt, lanes) for dt, lanes in uops_to_type_lanes(uops) if lanes > 1]
for name, (N, M, K), dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if self.is_cdna(self.target.arch):

View file

@ -153,15 +153,15 @@ extra_matcher = PatternMatcher([
# no int8 mul or cmove, cast to int16
(UPat.var("a", dtypes.int8s) * UPat.var("b"), lambda a,b: (a.cast(dtypes.int16) * b.cast(dtypes.int16)).cast(a.dtype)),
(UPat.var("m").where(UPat.var("a", (dtypes.bool,)+dtypes.int8s), UPat.var("b")),
lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.dtype.count == 1 else None),
lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.max_numel() == 1 else None),
# float16 alus are done in float32
(UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(x.op, dtypes.float.vec(x.dtype.count),
tuple(s.cast(dtypes.float) if s.dtype != dtypes.bool else s for s in x.src)).cast(x.dtype)),
(UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(Ops.CAST, x.dtype.scalar(), (UOp(x.op, dtypes.float,
tuple(UOp(Ops.CAST, dtypes.float, (s,)) if s.dtype != dtypes.bool else s for s in x.src)),))),
(UPat(GroupOp.Comparison, src=(UPat.var("a", dtypes.float16), UPat.var("b")), name="x"),
lambda x,a,b: UOp(x.op, x.dtype, (a.cast(dtypes.float32), b.cast(dtypes.float32))).cast(x.dtype)),
# no cmpne for packed ints, y != x => !(y==x)
(UPat(Ops.CMPNE, src=(UPat.var("y", dtypes.ints), UPat.var("x")), name="cmp"),
lambda y,x,cmp: UOp(Ops.CMPEQ, cmp.dtype, (y,x))^True if y.dtype.count > 1 else None),
lambda y,x,cmp: UOp(Ops.CMPEQ, cmp.dtype, (y,x))^True if y.max_numel() > 1 else None),
# float where expects a mask
(UPat.var("m", dtypes.bool).where(UPat.var("a", dtypes.floats), UPat.var("b")),
lambda m,a,b: m.cast(a.dtype).ne(0).where(a, b) if m.src[0].dtype not in dtypes.floats else None),
@ -174,26 +174,26 @@ extra_matcher = PatternMatcher([
# ***** X86 pre instruction selection *****
def gated_load(ctx, base:UOp, idx:UOp, cast:UOp, alt:UOp, gate:UOp, x:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count, AddrSpace.LOCAL), arg=next(ctx))
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.max_numel(), AddrSpace.LOCAL), arg=next(ctx))
local_idx = local.index(UOp.const(dtypes.int32, 0), ptr=True)
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.dtype.count == 1 else local).store(alt))
ptr = gate.where(base.index(idx, ptr=True), local_idx).after((local_idx if x.max_numel() == 1 else local).store(alt))
return ptr.cast(cast.dtype).load(dtype=x.dtype)
def gated_store(base:UOp, idx:UOp, cast:UOp, gate:UOp, val:UOp):
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count, AddrSpace.LOCAL), arg=-1)
local = UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.max_numel(), AddrSpace.LOCAL), arg=-1)
ptr = gate.where(base.index(idx, ptr=True), local.index(UOp.const(dtypes.int32, 0), ptr=True))
return ptr.cast(cast.dtype).store(val)
# these must be done in a separate matcher because they violate the spec
pre_isel_matcher = PatternMatcher([
# zero extending scalar 32bit int is a noop
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None),
(UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.max_numel() == 1 else None),
# cast between signed and unsigned int is a noop
(UPat.var("y", dtypes.ints+(dtypes.bool,)).cast(dtypes.ints, name="x"),
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize == y.dtype.itemsize else None),
# cast to < scalar int is a noop
(UPat.var("y", dtypes.ints).cast(dtypes.ints, name="x"),
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize < y.dtype.itemsize and y.dtype.count == 1 else None),
lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize < y.dtype.itemsize and y.max_numel() == 1 else None),
# bitcasts between scalar floats and ints are real, rest are noops
(UPat.var("y").bitcast().named("x"), lambda y,x: None if y.dtype in dtypes.floats and x.dtype in dtypes.ints or \
y.dtype in dtypes.ints and x.dtype in dtypes.floats else x.replace(op=Ops.NOOP)),
@ -208,7 +208,7 @@ pre_isel_matcher = PatternMatcher([
# TODO: remove this once we allow all flag producing ops in cmove
# if gate in scalar int cmove is not a comparison need to add one to set the flag
(UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")),
lambda m,a,b: m.ne(0).where(a,b) if m.op not in GroupOp.Comparison and a.dtype.count == 1 else None),
lambda m,a,b: m.ne(0).where(a,b) if m.op not in GroupOp.Comparison and a.max_numel() == 1 else None),
])
# ***** X86 registers *****
@ -252,8 +252,8 @@ def cmp(x:UOp) -> UOp:
return x.ins(X86Ops.CMP, dtype=dtypes.void) if (i:=to_imm(x.src[1])) is None else x.ins(X86Ops.CMPi, dtype=dtypes.void, src=(x.src[0], i))
def vcmp(x:UOp) -> UOp:
v = imm(dtypes.uint8, {Ops.CMPLT: 1, Ops.CMPNE: 4, Ops.CMPEQ: 0}[x.op])
if x.dtype.scalar() is dtypes.float32: return x.ins(X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (v,))
return x.ins(X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (v,))
if x.dtype.scalar() is dtypes.float32: return x.ins(X86Ops.VCMPSS if x.max_numel() == 1 else X86Ops.VCMPPS, src=x.src + (v,))
return x.ins(X86Ops.VCMPSD if x.max_numel() == 1 else X86Ops.VCMPPD, src=x.src + (v,))
# vshufps xmm2, xmm0, xmm1, imm
# for 128 bit xmm2 selects its lower 2 32 bits from xmm0 and its upper 2 32 bits from xmm1 according to imm
@ -296,7 +296,7 @@ def vpbroadcast(ctx:IselContext, x:UOp, y:UOp) -> UOp:
n = x.ins({1: X86Ops.VPBROADCASTB, 2: X86Ops.VPBROADCASTW, 4: X86Ops.VPBROADCASTD, 8: X86Ops.VPBROADCASTQ}[y.dtype.itemsize], src=(y,))
if y.op is Ops.LOAD and len(y.src) == 1 and is_foldable(ctx, n, y): return n
# if there isn't a load we can fold we need to move y from gpr to xmm
# this is hacky but required because int.vec(1) isn't supported
# this is hacky but required because scalar int bitcasts need a float register type
y = y if y.dtype.itemsize > 1 else y.cast(dtypes.int16)
return n.replace(src=(y.bitcast({2:dtypes.float16, 4:dtypes.float32, 8:dtypes.float64}[y.dtype.itemsize]),))
@ -351,17 +351,12 @@ def alloc_vregs(ctx:IselContext, x:UOp) -> UOp|None:
defs = []
if isinstance(x.tag, tuple): defs = [ctx.vreg(x.tag)]
elif x.dtype in dtypes.ints+(dtypes.bool,) or isinstance(x.dtype, PtrDType): defs = [ctx.vreg(WGPR)]
elif x.dtype in dtypes.floats or x.dtype.count > 1: defs = [ctx.vreg(XMM)]
elif x.dtype.scalar() in dtypes.floats or (x._shape is not None and x.max_numel() > 1) or (x.op is Ops.INS and x.arg.name.startswith("V")):
defs = [ctx.vreg(XMM)]
# TODO: add this once the scheduler can track register pressure
# if x.arg in X86GroupOp.WriteFlags: defs.append(ctx.vreg(RFLAGS))
return x.replace(tag=tuple(defs))
dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64)
dt_16bit = tuple(dt.vec(l) for dt in dts for l in [2,1] if l*dt.itemsize == 2 and dt not in dtypes.int16s)
dt_32bit = tuple(dt.vec(l) for dt in dts for l in [4,2,1] if l*dt.itemsize == 4 and dt not in dtypes.int32s)
dt_64bit = tuple(dt.vec(l) for dt in dts for l in [8,4,2,1] if l*dt.itemsize == 8 and dt not in dtypes.int64s)
dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if l*dt.itemsize == 16)
isel_matcher = PatternMatcher([
# **** Op -> Op ****
# cast to pointer is a noop
@ -376,7 +371,7 @@ isel_matcher = PatternMatcher([
# add callee saved registers to the RET, these will be scheduled at the top of the kernel and will be saved/restored if they are used in regalloc
# so regalloc builds the prologue/epilogue naturally
(UPat(Ops.SINK, name="x"), lambda x:
x.replace(src=(x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64.vec(2), r) for r in CALLEE_SAVED)),)) \
x.replace(src=(x.ins(X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64 if r in GPR else dtypes.float64, r) for r in CALLEE_SAVED)),)) \
if not x.src or x.src[0].arg is not X86Ops.RET else None),
# function abi constraints
(UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi),
@ -390,16 +385,16 @@ isel_matcher = PatternMatcher([
UOp.const(dt:=to_int(x.dtype), struct.unpack(dt.fmt, struct.pack(x.dtype.fmt, x.arg))[0]).bitcast(x.dtype) if not x.tag else None),
# TODO: these should use a.maximum(b) / a.minimum(b)
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float32), UPat.var("a")), lambda a,b:
a.ins(X86Ops.VMAXSS if a.dtype.count == 1 else X86Ops.VMAXPS, src=(a, b))),
a.ins(X86Ops.VMAXSS if a.max_numel() == 1 else X86Ops.VMAXPS, src=(a, b))),
((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.float64), UPat.var("a")), lambda a,b:
a.ins(X86Ops.VMAXSD if a.dtype.count == 1 else X86Ops.VMAXPD, src=(a, b))),
a.ins(X86Ops.VMAXSD if a.max_numel() == 1 else X86Ops.VMAXPD, src=(a, b))),
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda a,b:
a.ins(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, src=(a, b))),
a.ins(X86Ops.VMINSS if a.max_numel() == 1 else X86Ops.VMINPS, src=(a, b))),
((UPat.var("a") < UPat.var("b")).where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda a,b:
a.ins(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, src=(a, b))),
a.ins(X86Ops.VMINSD if a.max_numel() == 1 else X86Ops.VMINPD, src=(a, b))),
# conditional moves that use masks NOTE: these currently assume a mask producing cmp exists
(UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b:
a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None),
a.ins(X86Ops.VPBLENDVB, src=(b, a, m.replace(dtype=m.src[0].dtype))) if a.max_numel() > 1 else None),
(UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b:
a.ins(X86Ops.VBLENDVPS, src=(b, a, m.replace(dtype=m.src[0].dtype)))),
(UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b:
@ -434,19 +429,19 @@ isel_matcher = PatternMatcher([
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTD, src=(b, a))),
(UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.ins(X86Ops.VPCMPGTQ, src=(b, a))),
# float unary
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPS)),
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.dtype.count == 1 else x.ins(X86Ops.VSQRTPD)),
(UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSS, src=(y, y)) if x.max_numel() == 1 else x.ins(X86Ops.VSQRTPS)),
(UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: x.ins(X86Ops.VSQRTSD, src=(y, y)) if x.max_numel() == 1 else x.ins(X86Ops.VSQRTPD)),
(UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x:
x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
x.ins(X86Ops.VROUNDSS, src=(y, y, imm(dtypes.uint8, 3))) if x.max_numel() == 1 else x.ins(X86Ops.VROUNDPS, src=(y, imm(dtypes.uint8, 3)))),
(UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x:
x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
x.ins(X86Ops.VROUNDSD, src=(y, y, imm(dtypes.uint8, 3))) if x.max_numel() == 1 else x.ins(X86Ops.VROUNDPD, src=(y, imm(dtypes.uint8, 3)))),
# shufles
(UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: x.ins(X86Ops.VBROADCASTSS, src=(y,))),
# for float16 we route the srcs through gprs unless we can fold them, this is suboptimal for values in xmms, in that case we want vpunpcklwd
(UPat(Ops.STACK, dtypes.float16, name="x"), lambda ctx,x:
vpins(x.replace(src=tuple(s if s.op is Ops.LOAD and is_foldable(ctx, x, s) else s.bitcast(dtypes.int16) for s in x.src)))),
(UPat(Ops.STACK, (dtypes.float32.vec(4), dtypes.float32.vec(8)), name="x"), vshufps),
(UPat(Ops.STACK, (dtypes.float64.vec(2), dtypes.float64.vec(4)), name="x"), vshufpd),
(UPat(Ops.STACK, dtypes.float32, name="x"), lambda x: vshufps(x) if x.max_numel() in (4, 8) else None),
(UPat(Ops.STACK, dtypes.float64, name="x"), lambda x: vshufpd(x) if x.max_numel() in (2, 4) else None),
(UPat(Ops.STACK, dtypes.float32, name="x"), vinsertps),
(UPat.var("y", dtypes.ints+(dtypes.bool,)).broadcast(name="x"), vpbroadcast),
(UPat(Ops.STACK, dtypes.ints+(dtypes.bool,), name="x"), vpins),
@ -458,29 +453,29 @@ isel_matcher = PatternMatcher([
(UPat.var("y", dtypes.floats).gep(name="x"), lambda y,x: x.ins(X86Ops.VPSRLDQ, src=(y, imm(dtypes.uint8, x.arg[0] * x.dtype.itemsize)))),
# fused multiply add
((UPat(Ops.MUL, dtypes.float32, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
a.ins(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
a.ins(X86Ops.VFMADD213SS if a.max_numel() == 1 else X86Ops.VFMADD213PS, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
((UPat(Ops.MUL, dtypes.float64, name="a") + UPat.var("b")).named("c"), lambda ctx,a,b,c:
a.ins(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
a.ins(X86Ops.VFMADD213SD if a.max_numel() == 1 else X86Ops.VFMADD213PD, src=(*a.src, b)) if is_foldable(ctx, c, a) else None),
# packed bitwise
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.dtype.count > 1 else None),
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.dtype.count > 1 else None),
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.dtype.count > 1 else None),
((UPat() & UPat()).named("x"), lambda x: x.ins(X86Ops.VPAND) if x.max_numel() > 1 else None),
((UPat() | UPat()).named("x"), lambda x: x.ins(X86Ops.VPOR) if x.max_numel() > 1 else None),
((UPat() ^ UPat()).named("x"), lambda x: x.ins(X86Ops.VPXOR) if x.max_numel() > 1 else None),
# packed int binary
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.dtype.count > 1 else None),
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.dtype.count > 1 else None),
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.dtype.count > 1 else None),
((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVD) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.ins(X86Ops.VPSLLVQ) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVD) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRLVQ) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.ins(X86Ops.VPSRAVD) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDB) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDW) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDD) if x.max_numel() > 1 else None),
((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.ins(X86Ops.VPADDQ) if x.max_numel() > 1 else None),
(UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.ins(X86Ops.VPSUBB) if x.max_numel() > 1 else None),
(UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPSUBW) if x.max_numel() > 1 else None),
(UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPSUBD) if x.max_numel() > 1 else None),
(UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPSUBQ) if x.max_numel() > 1 else None),
(UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMULLW) if x.max_numel() > 1 else None),
(UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMULLD) if x.max_numel() > 1 else None),
# scalar int binary
((UPat(dtype=dtypes.ints).alu(Ops.CDIV, UPat())).named("x"), idiv),
# scalar int binary with immediate
@ -504,21 +499,21 @@ isel_matcher = PatternMatcher([
(UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b"), lambda a,b: a.ins(X86Ops.XOR, src=(a, b))),
(UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b"))), lambda a,b: a.ins(X86Ops.SUB, src=(a, b))),
# float binary
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)),
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)),
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)),
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)),
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)),
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)),
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)),
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)),
((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSS if x.max_numel() == 1 else X86Ops.VADDPS)),
((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.ins(X86Ops.VADDSD if x.max_numel() == 1 else X86Ops.VADDPD)),
((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSS if x.max_numel() == 1 else X86Ops.VMULPS)),
((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.ins(X86Ops.VMULSD if x.max_numel() == 1 else X86Ops.VMULPD)),
(UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VSUBSS if x.max_numel() == 1 else X86Ops.VSUBPS)),
(UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VSUBSD if x.max_numel() == 1 else X86Ops.VSUBPD)),
(UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VDIVSS if x.max_numel() == 1 else X86Ops.VDIVPS)),
(UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VDIVSD if x.max_numel() == 1 else X86Ops.VDIVPD)),
# casts
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PS) if x.max_numel() > 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTDQ2PD) if x.max_numel() > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPS2DQ) if x.max_numel() > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VCVTTPD2DQ) if x.max_numel() > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PD) if x.max_numel() > 1 else None),
(UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPD2PS) if x.max_numel() > 1 else None),
(UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.ins(X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))),
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.ins(X86Ops.VCVTPH2PS)),
(UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VCVTTSS2SI)),
@ -527,9 +522,9 @@ isel_matcher = PatternMatcher([
(UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSD2SS, src=(y, y))),
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))),
(UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.ins(X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))),
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX) if x.dtype.count == 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD) if x.dtype.count == 1 else None),
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX) if x.dtype.count == 1 else None),
(UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVZX) if x.max_numel() == 1 else None),
(UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.MOVSXD) if x.max_numel() == 1 else None),
(UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.ins(X86Ops.MOVSX) if x.max_numel() == 1 else None),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBW)),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBD)),
(UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.ins(X86Ops.VPMOVZXBQ)),
@ -554,20 +549,32 @@ isel_matcher = PatternMatcher([
# TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q
# copy, load, store
# NOTE: copy here violates the spec, it only happens post register allocation when a reg to reg move needs to be inserted
(UPat(Ops.COPY, dt_128bit, name="x"), lambda x: x.ins(X86Ops.VMOVUPS)),
(UPat(Ops.COPY, dt_64bit, name="x"), lambda x: x.ins(X86Ops.VMOVSD)),
(UPat(Ops.COPY, dt_32bit+dt_16bit, name="x"), lambda x: x.ins(X86Ops.VMOVSS)),
(UPat(Ops.COPY, name="x"), lambda x: x.ins(X86Ops.VMOVUPS) if x.dtype.itemsize == 16 else None),
(UPat(Ops.COPY, name="x"), lambda x: x.ins(X86Ops.VMOVSD) if x.dtype.itemsize == 8 and x.dtype.scalar() not in dtypes.int64s else None),
(UPat(Ops.COPY, name="x"), lambda x:
x.ins(X86Ops.VMOVSS) if x.dtype.itemsize in (2, 4) and
(x.dtype.scalar() in dtypes.floats or (x._shape is not None and x.max_numel() > 1)) else None),
(UPat(Ops.COPY, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.ins(X86Ops.MOV)),
(UPat(Ops.LOAD, dt_128bit, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVUPS, src=fold_address(a))),
(UPat(Ops.LOAD, dt_64bit, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVSD, src=fold_address(a))),
(UPat(Ops.LOAD, dt_32bit, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVSS, src=fold_address(a))),
(UPat(Ops.LOAD, dt_16bit, src=(UPat(name="a"),), name="x"), lambda x,a:
x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.tag),) + fold_address(a) + (imm(dtypes.uint8, 0),))),
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.VMOVUPS, src=fold_address(a)) if x.dtype.itemsize == 16 else None),
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a:
x.ins(X86Ops.VMOVSD, src=fold_address(a)) if x.dtype.itemsize == 8 and x.dtype.scalar() not in dtypes.int64s else None),
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a:
x.ins(X86Ops.VMOVSS, src=fold_address(a))
if x.dtype.itemsize == 4 and (x.dtype.scalar() in dtypes.floats or (x._shape is not None and x.max_numel() > 1)) else None),
(UPat(Ops.LOAD, src=(UPat(name="a"),), name="x"), lambda x,a:
x.ins(X86Ops.VPINSRW, src=(def_reg(x.dtype, x.tag),) + fold_address(a) + (imm(dtypes.uint8, 0),))
if x.dtype.itemsize == 2 and (x.dtype.scalar() in dtypes.floats or (x._shape is not None and x.max_numel() > 1)) else None),
(UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), src=(UPat(name="a"),), name="x"), lambda x,a: x.ins(X86Ops.MOV, src=fold_address(a))),
(UPat.var("a").store(UPat.var("b", dt_128bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVUPSm, src=fold_address(a) + (b,))),
(UPat.var("a").store(UPat.var("b", dt_64bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSDm, src=fold_address(a) + (b,))),
(UPat.var("a").store(UPat.var("b", dt_32bit), name="x"), lambda a,b,x: x.ins(X86Ops.VMOVSSm, src=fold_address(a) + (b,))),
(UPat.var("a").store(UPat.var("b", dt_16bit), name="x"), lambda a,b,x: x.ins(X86Ops.VPEXTRW, src=fold_address(a) + (b, imm(dtypes.uint8, 0)))),
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
x.ins(X86Ops.VMOVUPSm, src=fold_address(a) + (b,)) if b.dtype.itemsize == 16 else None),
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
x.ins(X86Ops.VMOVSDm, src=fold_address(a) + (b,)) if b.dtype.itemsize == 8 and b.dtype.scalar() not in dtypes.int64s else None),
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
x.ins(X86Ops.VMOVSSm, src=fold_address(a) + (b,))
if b.dtype.itemsize == 4 and (b.dtype.scalar() in dtypes.floats or (b._shape is not None and b.max_numel() > 1)) else None),
(UPat.var("a").store(UPat.var("b"), name="x"), lambda a,b,x:
x.ins(X86Ops.VPEXTRW, src=fold_address(a) + (b, imm(dtypes.uint8, 0)))
if b.dtype.itemsize == 2 and (b.dtype.scalar() in dtypes.floats or (b._shape is not None and b.max_numel() > 1)) else None),
(UPat.var("a").store(UPat.var("b", dtypes.ints+(dtypes.bool,)), name="x"), lambda a,b,x:
x.ins(X86Ops.MOVm, src=fold_address(a) + (b,)) if (i:=to_imm(b)) is None else x.ins(X86Ops.MOVi, src=fold_address(a) + (i,))),
# **** X86Op -> X86Op ****

View file

@ -8,12 +8,12 @@ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
from tinygrad.helpers import prod, Target, CPU_COUNT, getenv, OSX
def ldt(dt:DType):
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
def ldt(dt:DType, lanes:int=1):
if isinstance(dt, PtrDType): return ldt(dt.base.scalar()) + "*"
if lanes > 1: return f"<{lanes} x {ldt(dt.scalar())}>"
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8",
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt.scalar()]
def lconst(x, dtype:DType):
if dtype in dtypes.floats:
@ -39,13 +39,14 @@ def render_wmma_amx(ctx, wmma: UOp) -> str:
def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
return "\n".join([
*[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
*[f' store {ldt(src.dtype, src.max_numel())} {ctx[src]}, '
f'{ldt(src.dtype, src.max_numel())}* {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
*[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
*[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
f' {ctx[wmma]} = load {ldt(wmma.dtype, wmma.max_numel())}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16",
@ -54,12 +55,12 @@ def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
N,M,K = wmma.arg[1]
if cdna:
if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"})
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
return f" {ctx[wmma]} = call {ldt(wmma.dtype, wmma.max_numel())} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype, w.max_numel())} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
return f" {ctx[wmma]} = call {ldt(wmma.dtype, wmma.max_numel())} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype, w.max_numel())} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
if wmma.dtype.scalar() != dtypes.float else ")")
# llvm ops, lop[<dtype>][<op>]
@ -75,35 +76,47 @@ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop
base_rewrite = PatternMatcher([
# memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, "
f"{ldt(x.src[1].dtype, x.src[1].max_numel())} {ctx[x.src[1]]}"),
(UPat(Ops.LOAD, src=(UPat.var("idx"), UPat.var("alt"), UPat.var("mask")), name="x"),
lambda ctx,x,idx,alt,mask:
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
f" {ctx[x]}_yes = load {ldt(x.dtype, x.max_numel())}, {ldt(idx.dtype)} {ctx[idx]}\n"
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
f" {ctx[x]} = phi {ldt(x.dtype, x.max_numel())} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"),
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype, x.max_numel())}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat(Ops.STORE, name="x"), lambda ctx,x:
f" store {ldt(x.src[1].dtype, x.src[1].max_numel())} {ctx[x.src[1]]}, "
f"{ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
# GEP/VECTORIZE/CAST for float4 support
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
(UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y:
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype, y.max_numel())}> poison, {ldt(y.dtype, y.max_numel())} {ctx[y]}, i32 0\n"
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype, y.max_numel())}> {ctx[x]}_z, "
f"<1 x {ldt(y.dtype, y.max_numel())}> poison, <{x.max_numel()} x i32> zeroinitializer"),
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
f" = insertelement {ldt(x.dtype, x.max_numel())} "+
(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
f", {ldt(u.dtype, u.max_numel())} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
# unary/binary/ternary ops
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]} "
f"to {ldt(x.dtype, x.max_numel())}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x:
f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]} "
f"to {ldt(x.dtype, x.max_numel())}"),
(UPat(Ops.TRUNC, name="x"),
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype, x.max_numel())} @llvm.trunc.{ldt(x.dtype.scalar())}"
f"({ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]})"),
(UPat(GroupOp.Binary, name="x"), lambda ctx,x:
f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
(UPat(Ops.WHERE, name="x"), lambda ctx,x:
f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
f" {ctx[x]} = select {ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]}, "
f"{ldt(x.src[1].dtype, x.src[1].max_numel())} {ctx[x.src[1]]}, "
f"{ldt(x.src[2].dtype, x.src[2].max_numel())} {ctx[x.src[2]]}"),
# range
(UPat(Ops.RANGE, name="r"), lambda ctx,r:
@ -111,9 +124,9 @@ base_rewrite = PatternMatcher([
f"loop_entry_{range_str(r)}:\n"
f" br label %loop_latch_{range_str(r)}\n"
f"loop_latch_{range_str(r)}:\n"
f" {ctx[r]} = phi {ldt(r.dtype)} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_footer_{range_str(r)} ]\n"
f" {ctx[r]}phi = add {ldt(r.dtype)} {ctx[r]}, 1\n"
f" {ctx[r]}cmp = icmp ult {ldt(r.dtype)} {ctx[r]}, {ctx[r.src[0]]}\n"
f" {ctx[r]} = phi {ldt(r.dtype, r.max_numel())} [ 0, %loop_entry_{range_str(r)} ], [ {ctx[r]}phi, %loop_footer_{range_str(r)} ]\n"
f" {ctx[r]}phi = add {ldt(r.dtype, r.max_numel())} {ctx[r]}, 1\n"
f" {ctx[r]}cmp = icmp ult {ldt(r.dtype, r.max_numel())} {ctx[r]}, {ctx[r.src[0]]}\n"
f" br i1 {ctx[r]}cmp, label %loop_body_{range_str(r)}, label %loop_exit_{range_str(r)}\n"
f"loop_body_{range_str(r)}:"),
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE, name="r"))), lambda r:
@ -151,9 +164,10 @@ class LLVMRenderer(Renderer):
if self.tensor_cores == tc.amx and u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
vc += 1
r[u] = f"%wmma{vc}"
for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
for i, sz in enumerate(prod(size for _, size in upcast) for upcast in u.arg[6]):
dtype = u.arg[2]
kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype, sz)}, align {dtype.itemsize*sz}",
f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype, sz)}* {r[u]}_amx{i} to i64"]
name = "test"
for u in uops:
@ -170,15 +184,17 @@ class LLVMRenderer(Renderer):
elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
assert isinstance(u.dtype, PtrDType)
lanes = u.max_shape[-1] if len(u.max_shape) > 1 else 1
etype = ldt(u.dtype.base.scalar(), lanes)
if u.op is Ops.DEFINE_REG:
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {etype}]")
elif self.has_local:
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {ldt(u.dtype)}]*")
local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {etype}] undef, align 16")
kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {etype}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {etype}]*")
else:
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}], align 16")
kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {etype}], align 16")
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
elif u.op is Ops.CAST and (ldt(u.dtype) == ldt(u.src[0].dtype) or isinstance(u.dtype, PtrDType)):
elif u.op is Ops.CAST and (ldt(u.dtype, u.max_numel()) == ldt(u.src[0].dtype, u.src[0].max_numel()) or isinstance(u.dtype, PtrDType)):
r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop, or pointer cast
else:
# if it's an assign target, it's already preallocated
@ -226,19 +242,22 @@ class AMDLLVMRenderer(LLVMRenderer):
string_rewrite = PatternMatcher([
(UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{ code_for_workitem[x.arg[0]](x.arg[-1])}; "),
(UPat(tuple(llvm_intrinsics), name="x"),
lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype, x.max_numel())} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}"
f"({ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]})"),
(UPat(Ops.BARRIER), lambda ctx: barrier),
(UPat(Ops.CAST, dtypes.fp8s, (UPat(dtype=dtypes.float),), name="x",), lambda ctx,x:
f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype, x.src[0].max_numel())} {ctx[x.src[0]]}, "
f"i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",), lambda ctx,x,y:
f" {ctx[x.src[0]]}_i32 = zext i8 {ctx[x.src[0]]} to i32\n"
f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"),
]) + base_rewrite
extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([
(UPat(Ops.CAST, dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
lambda y: UOp(Ops.STACK, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
(UPat(Ops.CAST, dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
lambda y: UOp(Ops.STACK, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
(UPat(Ops.CAST, dtype=dtypes.half, src=UPat.var("y", dtypes.half), name="x"),
lambda x,y: UOp(Ops.STACK, dtypes.half, tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))
if x.max_numel() == 16 and y.max_numel() == 8 else None),
(UPat(Ops.CAST, dtype=dtypes.half, src=UPat.var("y", dtypes.half), name="x"),
lambda x,y: UOp(Ops.STACK, dtypes.half, tuple(y.gep(i * 2) for i in range(8))) if x.max_numel() == 8 and y.max_numel() == 16 else None),
# amd llvm intrinsics llvm.log2/llvm.exp2 don't support double
(UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
(UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),
@ -273,28 +292,37 @@ exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc
self.string_rewrite += PatternMatcher([(UPat(Ops.WMMA, name="wmma"), lambda ctx, wmma, cdna=self.is_cdna: render_wmma_amd(ctx, wmma, cdna))])
if self.is_cdna:
self.extra_matcher += PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint16.vec(4)), x.src[1].bitcast(dtypes.uint16.vec(4)),
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
lambda x: UOp(Ops.WMMA, dtypes.float, (x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16), x.src[2]), (*x.arg,))
if x.max_numel() == 4 and x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 4 else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
lambda x: UOp(Ops.WMMA, dtypes.float, (x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint64),
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint64), x.src[2]), (*x.arg,))
if x.max_numel() == 4 and x.src[0].dtype.scalar() in dtypes.fp8_ocp and x.src[0].max_numel() == 8 else None),
])
if target.arch in {"gfx1100", "gfx1151"}:
self.extra_matcher += PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8))),
(UPat(Ops.WMMA, name="x"), lambda x: UOp(Ops.WMMA, x.dtype, (x.src[0].bitcast(dtypes.uint16.vec(16)), x.src[1].bitcast(dtypes.uint16.vec(16)),
x.src[2]), x.arg) if x.src[0].dtype == dtypes.bfloat16.vec(16) else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.half),
lambda x: UOp(Ops.STACK, dtypes.half, tuple(UOp(Ops.WMMA, dtypes.half, (x.src[0], x.src[1],
UOp(Ops.STACK, dtypes.half, tuple(x.src[2].gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
(*x.arg,)).gep(i * 2) for i in range(8))) if x.max_numel() == 8 else None),
(UPat(Ops.WMMA, name="x"), lambda x: UOp(Ops.WMMA, x.dtype.scalar(),
(x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16), x.src[2]), x.arg)
if x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 16 else None),
])
if target.arch in {"gfx1200", "gfx1201"}:
self.extra_matcher += PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.uint16.vec(8),
(x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)), x.src[2].bitcast(dtypes.uint16.vec(8))), (*x.arg,))
.bitcast(dtypes.bfloat16.vec(8)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(8)),
lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
(UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16), lambda x: UOp(Ops.WMMA, dtypes.uint16,
(x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16),
x.src[2].replace(dtype=x.src[2].dtype.scalar()).bitcast(dtypes.uint16)), (*x.arg,))
.bitcast(dtypes.bfloat16) if x.max_numel() == 8 and x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 8 else None),
(UPat(Ops.WMMA, name="x", dtype=dtypes.float),
lambda x: UOp(Ops.WMMA, dtypes.float, (x.src[0].replace(dtype=x.src[0].dtype.scalar()).bitcast(dtypes.uint16),
x.src[1].replace(dtype=x.src[1].dtype.scalar()).bitcast(dtypes.uint16), x.src[2]), (*x.arg,))
if x.max_numel() == 8 and x.src[0].dtype.scalar() == dtypes.bfloat16 and x.src[0].max_numel() == 8 else None)
])
def supported_dtypes(self): return {d for d in super().supported_dtypes()

View file

@ -86,9 +86,10 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
nload = nir_instr(nc=lambda dtype,num_components:num_components, bs=lambda dtype,num_components:dtype.bitsize//num_components,
num_components=lambda dtype,num_components:num_components,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
lambda b, space, addr, dtype, num_components: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
@ -152,10 +153,11 @@ class NIRRenderer(Renderer):
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype, x.max_numel()),
lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype, x.max_numel())),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
@ -203,7 +205,7 @@ class NIRRenderer(Renderer):
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype, u.max_numel())
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
elif u.op == Ops.END:
r = u.src[1]

View file

@ -102,18 +102,18 @@ string_rewrite = PatternMatcher([
# store / gated load / load
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("var")), allow_any_len=True),
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
f"{f'.v{cnt}' if ((cnt:=var.max_numel())>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.max_numel() > 1 else ctx.r[var]};"),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(), UPat.var("alt"), UPat.var("gate"))),
lambda ctx, x, loc, alt, gate, buf: flatten([
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
]) if alt.dtype.count > 1 else [
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.max_numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
]) if x.max_numel() > 1 else [
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))).or_casted(),)),
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.max_numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
if x.max_numel() > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
# simple
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [
@ -218,14 +218,14 @@ class PTXRenderer(Renderer):
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
elif u.op is Ops.LOAD:
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.max_numel())] if u.max_numel() > 1 else ssa('val', u)
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg}", u.dtype))
elif u.op is Ops.WMMA:
# registers for packing/unpacking input and acc
self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.max_numel())]
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]),
Ops.PARAM: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))

View file

@ -94,8 +94,8 @@ class WGSLRenderer(CStyleLanguage):
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
]) + base_rewrite
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_cast(self, dt:DType, val: str, lanes:int=1) -> str: return f"{self.type_map[dt.scalar()]}({val})"
def render_dtype(self, dt:DType, mutable=True, lanes:int=1) -> str: return "var"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:

View file

@ -368,6 +368,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@property
def max_shape(self) -> tuple[int, ...]: return to_max_shape(self.shape)
def max_numel(self) -> int: return prod(self.max_shape)
@property
def shard_shape(self) -> tuple[sint, ...]:
if not isinstance(self.device, tuple) or self.axis is None: return self.shape