mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
785f09aac4 | ||
|
|
32ad5e8b96 |
1 changed files with 33 additions and 31 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Callable, cast, Any
|
||||
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
|
||||
from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
|
||||
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
|
||||
|
|
@ -11,7 +11,7 @@ import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
|
|||
def g(s:str): return getattr(mesa, s)
|
||||
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
|
||||
|
||||
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
|
||||
def glsl_type(t:DType): return {
|
||||
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
|
||||
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
|
||||
|
||||
|
|
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
|
|||
|
||||
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
|
||||
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
|
||||
if isinstance(it, PtrDType) and ot == dtypes.long: return src
|
||||
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
|
||||
|
||||
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
|
||||
|
|
@ -85,10 +84,10 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
|||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1,
|
||||
**({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize//u.max_numel(), num_components=lambda u:u.max_numel(),
|
||||
lambda b, space, addr, val: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize, num_components=lambda u:u.max_numel(),
|
||||
intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}),
|
||||
**({"ALIGN_MUL":u.dtype.itemsize} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
|
||||
**({"ALIGN_MUL":u.dtype.itemsize*u.max_numel()} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, u: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
|
||||
|
|
@ -105,17 +104,18 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
|
|||
|
||||
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
|
||||
|
||||
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
|
||||
def nidx(b:mesa.nir_builder, buf, off, space, itemsize, gate=None) -> mesa.nir_def:
|
||||
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
|
||||
def reg(b, buf):
|
||||
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
|
||||
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
|
||||
return deref
|
||||
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
|
||||
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
|
||||
f = (functools.partial(reg, b, buf) if space == AddrSpace.REG else
|
||||
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, itemsize, dtypes.long))))
|
||||
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
||||
|
||||
class NIRRenderer(Renderer):
|
||||
new_style = True
|
||||
suffix = "NIR"
|
||||
nir_options: bytes
|
||||
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
|
||||
|
|
@ -146,22 +146,23 @@ class NIRRenderer(Renderer):
|
|||
|
||||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True), UPat.var("alt"),
|
||||
UPat.var("gate")), name="x"),
|
||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x), lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x)),
|
||||
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize, ctx.r[gate]), x),
|
||||
lambda: ctx.r[alt])),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True),), name="x"),
|
||||
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), x)),
|
||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype][x.op], *[ctx.r[src] for src in x.src])),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
|
||||
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
|
||||
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
|
||||
mesa.glsl_array_type(glsl_type(x.dtype), x.max_numel(), 0).contents, f"acc{x.arg.slot}".encode()).contents),
|
||||
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
|
||||
|
|
@ -190,19 +191,21 @@ class NIRRenderer(Renderer):
|
|||
self.param_idx, ranges = 0, []
|
||||
|
||||
for u in uops:
|
||||
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
|
||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
|
||||
if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.STACK and len(u.src) == 0): pass
|
||||
elif u.op in {Ops.INDEX, Ops.SHRINK}:
|
||||
# INDEX on a register value picks the element, memory INDEX is handled in the LOAD/STORE patterns
|
||||
if u.src[0].op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER}: self.r[u] = nchannel(self.b, self.r[u.src[0]], u.src[1].arg)
|
||||
elif u.op is Ops.AFTER:
|
||||
self.r[u] = self.r[u.src[0]]
|
||||
elif u.op == Ops.SINK:
|
||||
if u.arg is not None:
|
||||
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
|
||||
elif u.op == Ops.DEFINE_LOCAL:
|
||||
elif u.op == Ops.BUFFER and u.addrspace == AddrSpace.LOCAL:
|
||||
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
|
||||
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
|
||||
self.b.shader.contents.info.shared_size += u.max_numel()*u.dtype.itemsize
|
||||
elif u.op == Ops.RANGE:
|
||||
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype))
|
||||
mesa.nir_push_loop(self.b)
|
||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
|
||||
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||
|
|
@ -211,7 +214,7 @@ class NIRRenderer(Renderer):
|
|||
next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
|
||||
# TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
|
||||
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype),
|
||||
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i),
|
||||
mesa.nir_pop_loop(self.b, None)
|
||||
else:
|
||||
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
|
||||
|
|
@ -254,7 +257,7 @@ class LVPRenderer(NIRRenderer):
|
|||
|
||||
def prerender(self, uops:list[UOp]):
|
||||
super().prerender(uops)
|
||||
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
|
||||
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
|
||||
|
||||
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
|
||||
|
|
@ -269,7 +272,6 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
|
|||
lambda b,img,idx_y,idx_x,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
||||
|
||||
class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
||||
new_style = False
|
||||
has_aux = True
|
||||
|
||||
def nload_img(ctx,img,idx_y,idx_x):
|
||||
|
|
@ -294,11 +296,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
def prerender(self, uops:list[UOp]):
|
||||
super().prerender(uops)
|
||||
self.texs:set[UOp] = set()
|
||||
self.uops, self.ibo_idx, self.img_idx = uops, 0, 0
|
||||
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
|
||||
self.img_idx = 0
|
||||
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
|
||||
|
||||
def postrender(self, uops:list[UOp]):
|
||||
bufs, texs, imgs = [u for u in uops if u.op == Ops.PARAM], itertools.count().__next__, itertools.count().__next__
|
||||
bufs, texs, imgs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not None], itertools.count().__next__, itertools.count().__next__
|
||||
for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int)
|
||||
|
||||
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue