Compare commits

...

2 commits

Author SHA1 Message Date
George Hotz
785f09aac4 lil cleanup 2026-06-11 17:44:22 -07:00
George Hotz
32ad5e8b96 port NIR to new_style (fable) 2026-06-11 17:34:01 -07:00

View file

@ -1,5 +1,5 @@
from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
@ -11,7 +11,7 @@ import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
def g(s:str): return getattr(mesa, s)
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
def glsl_type(t:DType): return {
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
@ -85,10 +84,10 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1,
**({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize//u.max_numel(), num_components=lambda u:u.max_numel(),
lambda b, space, addr, val: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize, num_components=lambda u:u.max_numel(),
intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}),
**({"ALIGN_MUL":u.dtype.itemsize} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
**({"ALIGN_MUL":u.dtype.itemsize*u.max_numel()} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, u: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
@ -105,17 +104,18 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
def nidx(b:mesa.nir_builder, buf, off, space, itemsize, gate=None) -> mesa.nir_def:
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
def reg(b, buf):
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
return deref
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
f = (functools.partial(reg, b, buf) if space == AddrSpace.REG else
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
class NIRRenderer(Renderer):
new_style = True
suffix = "NIR"
nir_options: bytes
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
@ -146,22 +146,23 @@ class NIRRenderer(Renderer):
def_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
(UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True), UPat.var("alt"),
UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x)),
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize, ctx.r[gate]), x),
lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), x)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents),
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
mesa.glsl_array_type(glsl_type(x.dtype), x.max_numel(), 0).contents, f"acc{x.arg.slot}".encode()).contents),
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
@ -190,19 +191,21 @@ class NIRRenderer(Renderer):
self.param_idx, ranges = 0, []
for u in uops:
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.STACK and len(u.src) == 0): pass
elif u.op in {Ops.INDEX, Ops.SHRINK}:
# INDEX on a register value picks the element, memory INDEX is handled in the LOAD/STORE patterns
if u.src[0].op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER}: self.r[u] = nchannel(self.b, self.r[u.src[0]], u.src[1].arg)
elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK:
if u.arg is not None:
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
elif u.op == Ops.DEFINE_LOCAL:
elif u.op == Ops.BUFFER and u.addrspace == AddrSpace.LOCAL:
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
self.b.shader.contents.info.shared_size += u.max_numel()*u.dtype.itemsize
elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype))
mesa.nir_push_loop(self.b)
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
@ -211,7 +214,7 @@ class NIRRenderer(Renderer):
next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
# TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype),
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i),
mesa.nir_pop_loop(self.b, None)
else:
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
@ -254,7 +257,7 @@ class LVPRenderer(NIRRenderer):
def prerender(self, uops:list[UOp]):
super().prerender(uops)
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
@ -269,7 +272,6 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
lambda b,img,idx_y,idx_x,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
class IR3Renderer(NIRRenderer, OpenCLRenderer):
new_style = False
has_aux = True
def nload_img(ctx,img,idx_y,idx_x):
@ -294,11 +296,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
def prerender(self, uops:list[UOp]):
super().prerender(uops)
self.texs:set[UOp] = set()
self.uops, self.ibo_idx, self.img_idx = uops, 0, 0
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
self.img_idx = 0
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
def postrender(self, uops:list[UOp]):
bufs, texs, imgs = [u for u in uops if u.op == Ops.PARAM], itertools.count().__next__, itertools.count().__next__
bufs, texs, imgs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not None], itertools.count().__next__, itertools.count().__next__
for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int)
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])