Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
785f09aac4 lil cleanup 2026-06-11 17:44:22 -07:00
George Hotz
32ad5e8b96 port NIR to new_style (fable) 2026-06-11 17:34:01 -07:00

View file

@ -1,5 +1,5 @@
from typing import Callable, cast, Any from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
from tinygrad.renderer import Renderer from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
@ -11,7 +11,7 @@ import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
def g(s:str): return getattr(mesa, s) def g(s:str): return getattr(mesa, s)
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d)) def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else { def glsl_type(t:DType): return {
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]}, **{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t] **{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b")) def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def: def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src) return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable): def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
@ -85,10 +84,10 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1,
**({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})}, **({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])( num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}"))) lambda b, space, addr, val: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize//u.max_numel(), num_components=lambda u:u.max_numel(), nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize, num_components=lambda u:u.max_numel(),
intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}),
**({"ALIGN_MUL":u.dtype.itemsize} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])( **({"ALIGN_MUL":u.dtype.itemsize*u.max_numel()} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, u: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}"))) lambda b, space, addr, u: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id)) ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
@ -105,17 +104,18 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def: def nidx(b:mesa.nir_builder, buf, off, space, itemsize, gate=None) -> mesa.nir_def:
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type)) @nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
def reg(b, buf): def reg(b, buf):
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array) deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off) deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
return deref return deref
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else f = (functools.partial(reg, b, buf) if space == AddrSpace.REG else
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long)))) lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f() return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
class NIRRenderer(Renderer): class NIRRenderer(Renderer):
new_style = True
suffix = "NIR" suffix = "NIR"
nir_options: bytes nir_options: bytes
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
@ -146,22 +146,23 @@ class NIRRenderer(Renderer):
def_rewrite = PatternMatcher([ def_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)), (UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)), (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))), (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))), (UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)), lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"), (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True), UPat.var("alt"),
UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x), lambda: ctx.r[alt])), lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize, ctx.r[gate]), x),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"), lambda: ctx.r[alt])),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x)), (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), x)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])), (UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])), (UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)), (UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]), (UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])), (UPat(Ops.BUFFER, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents), mesa.glsl_array_type(glsl_type(x.dtype), x.max_numel(), 0).contents, f"acc{x.arg.slot}".encode()).contents),
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)), (UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])), (UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]]))) (UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
@ -190,19 +191,21 @@ class NIRRenderer(Renderer):
self.param_idx, ranges = 0, [] self.param_idx, ranges = 0, []
for u in uops: for u in uops:
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.STACK and len(u.src) == 0): pass
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass elif u.op in {Ops.INDEX, Ops.SHRINK}:
# INDEX on a register value picks the element, memory INDEX is handled in the LOAD/STORE patterns
if u.src[0].op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER}: self.r[u] = nchannel(self.b, self.r[u.src[0]], u.src[1].arg)
elif u.op is Ops.AFTER: elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]] self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK: elif u.op == Ops.SINK:
if u.arg is not None: if u.arg is not None:
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char]) self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
elif u.op == Ops.DEFINE_LOCAL: elif u.op == Ops.BUFFER and u.addrspace == AddrSpace.LOCAL:
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long) self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
self.b.shader.contents.info.shared_size += u.dtype.nbytes() self.b.shader.contents.info.shared_size += u.max_numel()*u.dtype.itemsize
elif u.op == Ops.RANGE: elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents)) ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype) nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype))
mesa.nir_push_loop(self.b) mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u) self.r[u] = nload(self.b, AddrSpace.REG, i, u)
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break)) nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
@ -211,7 +214,7 @@ class NIRRenderer(Renderer):
next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype)) next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
# TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone # TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break)) nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype), nstore(self.b, AddrSpace.REG, ranges.pop(), next_i),
mesa.nir_pop_loop(self.b, None) mesa.nir_pop_loop(self.b, None)
else: else:
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}") if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
@ -254,7 +257,7 @@ class LVPRenderer(NIRRenderer):
def prerender(self, uops:list[UOp]): def prerender(self, uops:list[UOp]):
super().prerender(uops) super().prerender(uops)
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)]) self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int)) def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32 def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
@ -269,7 +272,6 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
lambda b,img,idx_y,idx_x,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load"))) lambda b,img,idx_y,idx_x,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
class IR3Renderer(NIRRenderer, OpenCLRenderer): class IR3Renderer(NIRRenderer, OpenCLRenderer):
new_style = False
has_aux = True has_aux = True
def nload_img(ctx,img,idx_y,idx_x): def nload_img(ctx,img,idx_y,idx_x):
@ -294,11 +296,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
def prerender(self, uops:list[UOp]): def prerender(self, uops:list[UOp]):
super().prerender(uops) super().prerender(uops)
self.texs:set[UOp] = set() self.texs:set[UOp] = set()
self.uops, self.ibo_idx, self.img_idx = uops, 0, 0 self.img_idx = 0
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)]) self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
def postrender(self, uops:list[UOp]): def postrender(self, uops:list[UOp]):
bufs, texs, imgs = [u for u in uops if u.op == Ops.PARAM], itertools.count().__next__, itertools.count().__next__ bufs, texs, imgs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not None], itertools.count().__next__, itertools.count().__next__
for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int) for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int)
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)]) self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])