fix for nir

This commit is contained in:
George Hotz 2026-06-01 21:46:45 -07:00
commit 394afe40c5
2 changed files with 63 additions and 24 deletions

View file

@ -1,17 +1,17 @@
from typing import Callable, cast, Any
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
from tinygrad.runtime.autogen import mesa
from tinygrad.runtime.support.c import POINTER
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools, os, warnings
def g(s:str): return getattr(mesa, s)
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
def glsl_type(t:DType): return {
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
if isinstance(it, PtrDType) and ot == dtypes.long: return src
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
@ -86,9 +85,9 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
nload = nir_instr(nc=lambda count:count, bs=lambda dtype:dtype.bitsize, num_components=lambda count:count,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
lambda b, space, addr, dtype, count=1: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
@ -104,16 +103,31 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
def _load_count(x:UOp) -> int: return x.max_numel() if 1 < x.max_numel() <= 4 else 1
def _pad_count(b:mesa.nir_builder, dtype:DType, count:int, val):
return val if val.num_components == count else nalu(b, f"vec{count}", val, *[nundef(b, dtype) for _ in range(count-1)])
def nidx(b:mesa.nir_builder, buf, off, dtype, addrspace, gate=None) -> mesa.nir_def:
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
def reg(b, buf):
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
return deref
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
f = (functools.partial(reg, b, buf) if addrspace == AddrSpace.REG else
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
def ngated_load_index(ctx, x, buf, off, alt, gate):
cnt = _load_count(x)
return if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace, ctx.r[gate]), x.dtype, cnt),
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
def ngated_load_shrink(ctx, x, idx, alt, gate):
cnt = _load_count(idx)
return if_phi(ctx.b, ctx.r[gate], lambda: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, cnt),
lambda: _pad_count(ctx.b, x.dtype, cnt, ctx.r[alt]))
class NIRRenderer(Renderer):
suffix = "NIR"
nir_options: bytes
@ -137,7 +151,7 @@ class NIRRenderer(Renderer):
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
src=(buf,off.cast(dtypes.long))) if buf.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
# images need index to be int for nir
(UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")),
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))),
@ -149,18 +163,28 @@ class NIRRenderer(Renderer):
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), ctx.r[val], val.dtype)),
(UPat(Ops.STORE, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("val"))),
lambda ctx,idx,val: nstore(ctx.b, idx.addrspace, ctx.r[idx], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
ngated_load_index),
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("alt"), UPat.var("gate")), name="x"),
ngated_load_shrink),
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), x.dtype, _load_count(x))),
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"),), name="x"),
lambda ctx,x,idx: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, _load_count(idx))),
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("off"), UPat.cvar()), name="x"),
lambda ctx,x,buf,off: nidx(ctx.b, ctx.r[buf], ctx.r[off], x.dtype, x.addrspace)),
(UPat(Ops.STACK, name="x"), lambda ctx,x: ctx.r[x.src[0]] if len(x.src) == 1 else
nalu(ctx.b, f"vec{len(x.src)}", *[ctx.r[src] for src in x.src])),
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents),
(UPat(Ops.INDEX, src=(UPat.var("a"), UPat.cvar("idx"))),
lambda ctx,a,idx: nchannel(ctx.b, ctx.r[a], idx.arg) if a.addrspace == AddrSpace.ANON else None),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
mesa.glsl_array_type(glsl_type(x.dtype), x.src[0].arg, 0), f"acc{x.arg}".encode()).contents),
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
@ -189,8 +213,7 @@ class NIRRenderer(Renderer):
self.param_idx, ranges = 0, []
for u in uops:
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.INDEX and u.src[0].addrspace != AddrSpace.ANON): pass
elif u.op is Ops.AFTER:
self.r[u] = self.r[u.src[0]]
elif u.op == Ops.SINK:
@ -198,7 +221,7 @@ class NIRRenderer(Renderer):
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
elif u.op == Ops.DEFINE_LOCAL:
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
self.b.shader.contents.info.shared_size += u.src[0].arg * u.dtype.itemsize
elif u.op == Ops.RANGE:
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
@ -217,7 +240,17 @@ class NIRRenderer(Renderer):
self.r[u] = cast(mesa.nir_def, d)
self.postrender(uops)
mesa.nir_validate_shader(self.b.shader, b"after render")
if DEBUG >= 2 and hasattr(os, "fork"):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
pid = os.fork()
if pid == 0:
mesa.nir_validate_shader(self.b.shader, b"after render")
os._exit(0)
_, status = os.waitpid(pid, 0)
if os.WIFSIGNALED(status): raise RuntimeError(f"NIR validation failed after render with signal {os.WTERMSIG(status)}")
if os.WEXITSTATUS(status) != 0: raise RuntimeError(f"NIR validation failed after render with exit code {os.WEXITSTATUS(status)}")
else: mesa.nir_validate_shader(self.b.shader, b"after render")
if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')),
"__stdoutp" if OSX else "stdout"))
mesa.nir_serialize(blob:=mesa.struct_blob(), self.b.shader, False)
@ -257,9 +290,10 @@ class LVPRenderer(NIRRenderer):
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=4,
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
srcs=lambda b,img,idx_y,idx_x,val:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
srcs=lambda b,img,idx_y,idx_x,val,dtype:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int),
val if val.num_components == 4 else nalu(b, "vec4", val, nundef(b, dtype), nundef(b, dtype), nundef(b, dtype)), nimm(b, 0, dtypes.int)]])(
lambda b,img,idx_y,idx_x,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
@ -277,8 +311,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
def_rewrite = PatternMatcher([
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("val")), allow_any_len=True),
lambda ctx,img,idx_y,idx_x,val: nstore_img(ctx.b, ctx.r[img], ctx.r[idx_y], ctx.r[idx_x], ctx.r[val], val.dtype)),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate"))),
lambda ctx,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, idx_y, idx_x), lambda: ctx.r[alt])),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate")), name="x"),
lambda ctx,x,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate],
lambda: ctx.nload_img(img, idx_y, idx_x) if len(x.shape) > 0 and x.shape[-1] == 4 else nchannel(ctx.b, ctx.nload_img(img, idx_y, idx_x), 0),
lambda: ctx.r[alt] if len(x.shape) == 0 or x.shape[-1] != 4 or ctx.r[alt].num_components == 4 else
nalu(ctx.b, "vec4", ctx.r[alt], nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype)))),
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')),)), nload_img),
]) + NIRRenderer.def_rewrite

View file

@ -114,6 +114,8 @@ class DLL(ctypes.CDLL):
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
self.nm, self.emsg = nm, emsg or f"try setting {nm.upper()+'_PATH'}?"
if nm == 'llvm' and (ver:=getenv("LLVM_VERSION", "")):
paths = ([f"/opt/homebrew/opt/llvm@{ver}/lib/libLLVM.dylib"] if OSX else [f"LLVM-{ver}"]) + (paths if isinstance(paths, list) else [paths])
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
if DEBUG >= 3: print(f"loading {nm} from {path}")
try: