Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
f8a66639a6 junk 2026-04-29 21:55:30 +00:00
George Hotz
0095b91ce6 fixes 2026-04-29 21:46:08 +00:00
George Hotz
3e13225f62 switch the renderers to use shape instead of dtype.count 2026-04-29 21:38:15 +00:00
7 changed files with 30 additions and 28 deletions

View file

@ -55,7 +55,7 @@ class Estimates:
lds += u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.numel()
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))

View file

@ -21,7 +21,8 @@ base_rewrite = PatternMatcher([
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x:
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})"
if not isinstance(x.dtype, PtrDType) and x._shape is not None and x.numel() > 1 else None),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x:
f"__builtin_bit_cast({ctx.render_dtype(x.dtype)}, ({ctx.render_dtype(x.src[0].dtype)})({ctx[x.src[0]]}))"),
@ -55,7 +56,7 @@ base_rewrite = PatternMatcher([
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
(f"[{x.arg[0]}]" if x.src[0].numel() > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
# custom passes through with format
(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
])

View file

@ -91,7 +91,7 @@ base_rewrite = PatternMatcher([
(UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
(UPat(Ops.STACK, src=UPat.var('y'), name="x"), lambda ctx,x,y:
f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.numel()} x i32> zeroinitializer"),
(UPat(Ops.STACK, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),

View file

@ -86,9 +86,9 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
nload = nir_instr(nc=lambda x:x.numel(), bs=lambda x:x.dtype.scalar().bitsize, num_components=lambda x:x.numel(),
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
lambda b, space, addr, x: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
@ -150,10 +150,10 @@ class NIRRenderer(Renderer):
lambda ctx,buf,off,val: nstore(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True, name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
lambda: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))),), allow_any_len=True, name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
lambda ctx,x,buf,off: nload(ctx.b, buf.ptrdtype.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.numel()}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
@ -200,7 +200,7 @@ class NIRRenderer(Renderer):
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u.dtype)
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
elif u.op == Ops.END:
r = u.src[1]

View file

@ -104,18 +104,18 @@ string_rewrite = PatternMatcher([
# store / gated load / load
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc")), allow_any_len=True), UPat.var("var"))),
lambda ctx, loc, var, buf: f"st.{mem_type(buf)}" + \
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
f"{f'.v{cnt}' if ((cnt:=var.numel())>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[loc]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.numel() > 1 else ctx.r[var]};"),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"), UPat.var("gate"))), UPat.var("alt")), allow_any_len=True),
lambda ctx, x, loc, alt, gate, buf: flatten([
[f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
]) if alt.dtype.count > 1 else [
[f"@{ctx.r[gate]} ld.{mem_type(buf)}.v{x.numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
]) if x.numel() > 1 else [
f"@{ctx.r[gate]} ld.{mem_type(buf)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
(UPat(Ops.LOAD, name="x", src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("loc"))),), allow_any_len=True),
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
if x.dtype.count > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
lambda ctx, x, loc, buf: f"ld.{mem_type(buf)}.v{x.numel()}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
if x.numel() > 1 else f"ld.{mem_type(buf)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
# simple
(UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
(UPat(Ops.RANGE, name="r"), lambda ctx, r: [
@ -220,14 +220,14 @@ class PTXRenderer(Renderer):
elif u.op is Ops.DEFINE_VAR: bufs.append((u.expr, u.dtype))
elif u.op is Ops.LOAD:
assert u.src[0].dtype == dtypes.int64, "load isn't int64"
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.numel())] if u.numel() > 1 else ssa('val', u)
elif u.op is Ops.PARAM: bufs.append((f"data{u.arg}", u.dtype))
elif u.op is Ops.WMMA:
# registers for packing/unpacking input and acc
self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.numel())]
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]),
Ops.PARAM: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))

View file

@ -546,7 +546,7 @@ def split_store(x:UOp) -> UOp|None:
# if we have any open ranges here, we don't split
if x.ranges: return None
# raw STORE (not from bufferize_to_store) should be processed through its END wrapper, not independently
if x.op is Ops.STORE and x.src[0]._shape is not None: return None
if x.op is Ops.STORE and x.src[0]._shape is not None and x.src[0].shape != (): return None
# local kernel rewrite
lctx = LocalAddBufferContext()

View file

@ -209,10 +209,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property
def _shape(self) -> tuple[sint, ...]|None:
if self.dtype.count > 1: return (self.dtype.count,)
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.STACK | Ops.GEP | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | \
Ops.STACK | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | Ops.END | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS | Ops.TUPLE | Ops.CALL | Ops.FUNCTION:
return None
@ -233,12 +234,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return None
case Ops.INDEX:
# non pointer index doesn't have a shape
if not isinstance(self.dtype, PtrDType): return None
# fully indexed doesn't have a shape. TODO: remove this
if self.src[0]._shape is None or len(self.src[1:]) == len(self.src[0].shape): return None
# pointer index
return self.src[0].shape[len(self.src[1:]):]
if self.src[0]._shape is None: return None
idx_srcs = self.src[1:-1] if len(self.src) > 1 and self.src[-1].dtype is dtypes.bool else self.src[1:]
return self.src[0].shape[len(idx_srcs):]
case Ops.GEP:
return None if self.src[0]._shape is None else self.src[0].shape[:-1] + ((len(self.arg),) if len(self.arg) > 1 else ())
# some ops init the shape
case Ops.CONST | Ops.DEFINE_VAR | Ops.BIND | Ops.RANGE | Ops.SPECIAL: return ()
@ -492,7 +493,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
arg=dtype.const(b),
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
@staticmethod
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
shape:tuple[sint, ...]|None=None, unique=True):
@ -500,7 +501,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert not isinstance(fill_value, (UOp, tuple)), "unique const only works on numbers"
ret = UOp.const(to_dtype(dtype) if dtype is not None else dtypes.from_py(fill_value), fill_value, canonicalize_device(device))
ret = ret.replace(src=(UOp.unique(None if unique is True else unique),) + ret.src)
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None else ret
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)