mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix for nir
This commit is contained in:
parent
774847a54d
commit
394afe40c5
2 changed files with 63 additions and 24 deletions
|
|
@ -1,17 +1,17 @@
|
|||
from typing import Callable, cast, Any
|
||||
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
|
||||
from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
|
||||
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str
|
||||
from tinygrad.runtime.autogen import mesa
|
||||
from tinygrad.runtime.support.c import POINTER
|
||||
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
|
||||
import base64, ctypes, ctypes.util, struct, functools, inspect, itertools, os, warnings
|
||||
|
||||
def g(s:str):
  """Fetch the attribute named `s` from the `mesa` autogen module and return it."""
  attr = getattr(mesa, s)
  return attr
|
||||
def nsrc(d:mesa.nir_def) -> mesa.nir_src:
  """Wrap `d` in a `mesa.nir_src` whose `ssa` field is a ctypes pointer to `d`."""
  ssa_ptr = ctypes.pointer(d)
  return mesa.nir_src(ssa=ssa_ptr)
|
||||
|
||||
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
|
||||
def glsl_type(t:DType): return {
|
||||
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
|
||||
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
|
||||
|
||||
|
|
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
|
|||
|
||||
def c(t:DType, u:bool=True) -> str:
  """Return a one-letter type-class code for `t`: "u", "i", "f" or "b".

  "u" is produced only when `t` is in `dtypes.uints` and `u` is truthy. Otherwise membership in
  `dtypes.ints` and then `dtypes.floats` is checked, and anything else maps to "b".
  """
  if t in dtypes.uints and u: return "u"
  if t in dtypes.ints: return "i"
  if t in dtypes.floats: return "f"
  return "b"
|
||||
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
|
||||
if isinstance(it, PtrDType) and ot == dtypes.long: return src
|
||||
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
|
||||
|
||||
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
|
||||
|
|
@ -86,9 +85,9 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
|||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
nload = nir_instr(nc=lambda count:count, bs=lambda dtype:dtype.bitsize, num_components=lambda count:count,
|
||||
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
lambda b, space, addr, dtype, count=1: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
|
||||
nlid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_local_invocation_id))
|
||||
|
|
@ -104,16 +103,31 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
|
|||
|
||||
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn):
  """Call `nif` with the two branch callables, feed its results to `mesa.nir_if_phi`, and return the pointed-to value."""
  branch_results = nif(b, cond, then_fn, else_fn)
  phi = mesa.nir_if_phi(b, *branch_results)
  return phi.contents
|
||||
|
||||
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
|
||||
def _load_count(x:UOp) -> int:
  """Return `x.max_numel()` when it is greater than 1 and at most 4; otherwise return 1."""
  if not (1 < x.max_numel() <= 4): return 1
  return x.max_numel()
|
||||
def _pad_count(b:mesa.nir_builder, dtype:DType, count:int, val):
  """Return `val` as-is if it already has `count` components, else a `vec{count}` of `val` followed by `count-1` undefs of `dtype`."""
  if val.num_components == count: return val
  padding = [nundef(b, dtype) for _ in range(count-1)]
  return nalu(b, f"vec{count}", val, *padding)
|
||||
|
||||
def nidx(b:mesa.nir_builder, buf, off, dtype, addrspace, gate=None) -> mesa.nir_def:
|
||||
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
|
||||
def reg(b, buf):
|
||||
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
|
||||
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
|
||||
return deref
|
||||
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
|
||||
f = (functools.partial(reg, b, buf) if addrspace == AddrSpace.REG else
|
||||
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
|
||||
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
||||
|
||||
def ngated_load_index(ctx, x, buf, off, alt, gate):
  """Emit a gated load addressed by `buf`/`off`, with `alt` as the fallback value.

  `if_phi` on `ctx.r[gate]` selects between an `nload` of `_load_count(x)` components from the
  address computed by `nidx`, and `ctx.r[alt]` passed through `_pad_count` with the same count.
  """
  width = _load_count(x)
  def load_branch():
    addr = nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace, ctx.r[gate])
    return nload(ctx.b, buf.addrspace, addr, x.dtype, width)
  def alt_branch():
    return _pad_count(ctx.b, x.dtype, width, ctx.r[alt])
  return if_phi(ctx.b, ctx.r[gate], load_branch, alt_branch)
|
||||
|
||||
def ngated_load_shrink(ctx, x, idx, alt, gate):
  """Emit a gated load from the already-rendered address `ctx.r[idx]`, with `alt` as the fallback value.

  `if_phi` on `ctx.r[gate]` selects between an `nload` of `_load_count(idx)` components and
  `ctx.r[alt]` passed through `_pad_count` with the same count.
  """
  width = _load_count(idx)
  def load_branch():
    return nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, width)
  def alt_branch():
    return _pad_count(ctx.b, x.dtype, width, ctx.r[alt])
  return if_phi(ctx.b, ctx.r[gate], load_branch, alt_branch)
|
||||
|
||||
class NIRRenderer(Renderer):
|
||||
suffix = "NIR"
|
||||
nir_options: bytes
|
||||
|
|
@ -137,7 +151,7 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
# load/store use pointer arithmetic, and the cast does nothing. NOTE: this doesn't apply to image indexing cause it's 1-D
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
src=(buf,off.cast(dtypes.long))) if buf.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.STACK) else None),
|
||||
# images need index to be int for nir
|
||||
(UPat.var("buf").index(UPat.var("idx_y"), UPat.var("idx_x")),
|
||||
lambda buf,idx_y,idx_x: buf.index(idx_y.cast(dtypes.int), idx_x.cast(dtypes.int))),
|
||||
|
|
@ -149,18 +163,28 @@ class NIRRenderer(Renderer):
|
|||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("val"))),
|
||||
lambda ctx,idx,val: nstore(ctx.b, idx.addrspace, ctx.r[idx], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x.dtype), lambda: ctx.r[alt])),
|
||||
ngated_load_index),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
ngated_load_shrink),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x.dtype)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.dtype.count}", *[ctx.r[src] for src in x.src])),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, buf.addrspace), x.dtype, _load_count(x))),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.SHRINK, name="idx"),), name="x"),
|
||||
lambda ctx,x,idx: nload(ctx.b, idx.addrspace, ctx.r[idx], x.dtype, _load_count(idx))),
|
||||
(UPat(Ops.SHRINK, src=(UPat.var("buf"), UPat.var("off"), UPat.cvar()), name="x"),
|
||||
lambda ctx,x,buf,off: nidx(ctx.b, ctx.r[buf], ctx.r[off], x.dtype, x.addrspace)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: ctx.r[x.src[0]] if len(x.src) == 1 else
|
||||
nalu(ctx.b, f"vec{len(x.src)}", *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
|
||||
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("a"), UPat.cvar("idx"))),
|
||||
lambda ctx,a,idx: nchannel(ctx.b, ctx.r[a], idx.arg) if a.addrspace == AddrSpace.ANON else None),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
|
||||
mesa.glsl_array_type(glsl_type(x.dtype), x.src[0].arg, 0), f"acc{x.arg}".encode()).contents),
|
||||
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
|
||||
|
|
@ -189,8 +213,7 @@ class NIRRenderer(Renderer):
|
|||
self.param_idx, ranges = 0, []
|
||||
|
||||
for u in uops:
|
||||
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
|
||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
|
||||
if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.INDEX and u.src[0].addrspace != AddrSpace.ANON): pass
|
||||
elif u.op is Ops.AFTER:
|
||||
self.r[u] = self.r[u.src[0]]
|
||||
elif u.op == Ops.SINK:
|
||||
|
|
@ -198,7 +221,7 @@ class NIRRenderer(Renderer):
|
|||
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
|
||||
elif u.op == Ops.DEFINE_LOCAL:
|
||||
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
|
||||
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
|
||||
self.b.shader.contents.info.shared_size += u.src[0].arg * u.dtype.itemsize
|
||||
elif u.op == Ops.RANGE:
|
||||
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
||||
|
|
@ -217,7 +240,17 @@ class NIRRenderer(Renderer):
|
|||
self.r[u] = cast(mesa.nir_def, d)
|
||||
self.postrender(uops)
|
||||
|
||||
mesa.nir_validate_shader(self.b.shader, b"after render")
|
||||
if DEBUG >= 2 and hasattr(os, "fork"):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
mesa.nir_validate_shader(self.b.shader, b"after render")
|
||||
os._exit(0)
|
||||
_, status = os.waitpid(pid, 0)
|
||||
if os.WIFSIGNALED(status): raise RuntimeError(f"NIR validation failed after render with signal {os.WTERMSIG(status)}")
|
||||
if os.WEXITSTATUS(status) != 0: raise RuntimeError(f"NIR validation failed after render with exit code {os.WEXITSTATUS(status)}")
|
||||
else: mesa.nir_validate_shader(self.b.shader, b"after render")
|
||||
if DEBUG >= 4: mesa.nir_print_shader(self.b.shader, ctypes.POINTER(mesa.struct__IO_FILE).in_dll(ctypes.CDLL(ctypes.util.find_library('c')),
|
||||
"__stdoutp" if OSX else "stdout"))
|
||||
mesa.nir_serialize(blob:=mesa.struct_blob(), self.b.shader, False)
|
||||
|
|
@ -257,9 +290,10 @@ class LVPRenderer(NIRRenderer):
|
|||
|
||||
def tovec(b, idx_y, idx_x):
  """Build a "vec4" of (idx_x, idx_y, undef, undef), where both undefs are of `dtypes.int`."""
  pad_z, pad_w = nundef(b, dtypes.int), nundef(b, dtypes.int)
  return nalu(b, "vec4", idx_x, idx_y, pad_z, pad_w)
|
||||
def nfloat(dtype):
  """Return `mesa.nir_type_float16` when `dtype == dtypes.half`, and `mesa.nir_type_float32` for anything else."""
  if dtype == dtypes.half: return mesa.nir_type_float16
  return mesa.nir_type_float32
|
||||
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=lambda val:val.num_components,
|
||||
nstore_img = nir_instr(has_def=False, df=lambda img:img, num_components=4,
|
||||
intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'SRC_TYPE':nfloat(dtype)},
|
||||
srcs=lambda b,img,idx_y,idx_x,val:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int), val, nimm(b, 0, dtypes.int)]])(
|
||||
srcs=lambda b,img,idx_y,idx_x,val,dtype:[nsrc(x) for x in [img, tovec(b, idx_y, idx_x), nundef(b, dtypes.int),
|
||||
val if val.num_components == 4 else nalu(b, "vec4", val, nundef(b, dtype), nundef(b, dtype), nundef(b, dtype)), nimm(b, 0, dtypes.int)]])(
|
||||
lambda b,img,idx_y,idx_x,val,dtype:mesa.nir_intrinsic_instr_create(b.shader,g("nir_intrinsic_image_store")))
|
||||
|
||||
_nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2D, 'ACCESS':mesa.ACCESS_CAN_REORDER, 'DEST_TYPE':nfloat(dtype)},
|
||||
|
|
@ -277,8 +311,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.STORE, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("val")), allow_any_len=True),
|
||||
lambda ctx,img,idx_y,idx_x,val: nstore_img(ctx.b, ctx.r[img], ctx.r[idx_y], ctx.r[idx_x], ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate"))),
|
||||
lambda ctx,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate], lambda: ctx.nload_img(img, idx_y, idx_x), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,img,idx_y,idx_x,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: ctx.nload_img(img, idx_y, idx_x) if len(x.shape) > 0 and x.shape[-1] == 4 else nchannel(ctx.b, ctx.nload_img(img, idx_y, idx_x), 0),
|
||||
lambda: ctx.r[alt] if len(x.shape) == 0 or x.shape[-1] != 4 or ctx.r[alt].num_components == 4 else
|
||||
nalu(ctx.b, "vec4", ctx.r[alt], nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype), nundef(ctx.b, x.dtype)))),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('img').index(UPat.var('idx_y'), UPat.var('idx_x')),)), nload_img),
|
||||
]) + NIRRenderer.def_rewrite
|
||||
|
||||
|
|
|
|||
|
|
@ -114,6 +114,8 @@ class DLL(ctypes.CDLL):
|
|||
|
||||
def __init__(self, nm:str, paths:str|list[str], extra_paths=[], emsg="", **kwargs):
|
||||
self.nm, self.emsg = nm, emsg or f"try setting {nm.upper()+'_PATH'}?"
|
||||
if nm == 'llvm' and (ver:=getenv("LLVM_VERSION", "")):
|
||||
paths = ([f"/opt/homebrew/opt/llvm@{ver}/lib/libLLVM.dylib"] if OSX else [f"LLVM-{ver}"]) + (paths if isinstance(paths, list) else [paths])
|
||||
if (path:= DLL.findlib(nm, paths if isinstance(paths, list) else [paths], extra_paths if isinstance(extra_paths, list) else [extra_paths])):
|
||||
if DEBUG >= 3: print(f"loading {nm} from {path}")
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue