bitsize in nir

This commit is contained in:
George Hotz 2026-06-02 13:11:31 -07:00
commit 431accc9b7
2 changed files with 19 additions and 23 deletions

View file

@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType):
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, truncate[dtype](x))
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents._def, x, dtype)
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize*dtype.count)).contents._def, x, dtype)
return instr
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize)
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize*dtype.count)
deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108
lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var))
@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize, num_components=lambda dtype:dtype.count,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))

View file

@ -27,12 +27,8 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
val = (load.cast(dtypes.uint32) >> shift_am) & mask
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
def is_packed(dt:DType, odt:DType|None = None) -> bool:
if odt is None: odt = dt
# registers aren't packed
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
return dt.itemsize < 4 and dt.base != dtypes.half
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
def is_packed(u:UOp) -> bool: return u.dtype.itemsize < 4 and u.dtype.base != dtypes.half and u.addrspace != AddrSpace.REG
def _packed_size(u:UOp): return u.max_numel() // (4//u.dtype.itemsize) if is_packed(u) else u.max_numel()
def is_nan(a):
bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype)
@ -43,12 +39,12 @@ wgsl_matcher = PatternMatcher([
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(b) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(b) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(bidx) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")),
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
lambda bidx,var: packed_store(bidx,var) if is_packed(bidx) else None),
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
# fix nan check: 'a != a -> is_nan()'
@ -73,8 +69,8 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
@ -86,11 +82,11 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0])}, {ctx[gate]})"),
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b)),
(UPat.store(UPat.var("b"), UPat.var("v")), lambda ctx,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0]) \
else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
@ -98,8 +94,8 @@ class WGSLRenderer(CStyleLanguage):
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
def render_load(self, x:str, u:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(u) else x
def buf_map(self, u:UOp) -> str: return "atomic<u32>" if is_packed(u) else self.type_map[u.dtype.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
if not local_size: local_size = [1]
@ -111,7 +107,7 @@ class WGSLRenderer(CStyleLanguage):
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
f"{name}:{f'array<{self.buf_map(u)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u)};" for name,(u,_) in bufs])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"