mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
785f09aac4 | ||
|
|
32ad5e8b96 |
1 changed file with 33 additions and 31 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Callable, cast, Any
|
from typing import Callable, cast, Any
|
||||||
from tinygrad.dtype import AddrSpace, DType, PtrDType, ImageDType, dtypes, truncate
|
from tinygrad.dtype import AddrSpace, DType, ImageDType, dtypes, truncate
|
||||||
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
|
from tinygrad.helpers import DEBUG, OSX, unwrap, fromimport, Target
|
||||||
from tinygrad.renderer import Renderer
|
from tinygrad.renderer import Renderer
|
||||||
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
|
from tinygrad.renderer.cstyle import CUDARenderer, OpenCLRenderer
|
||||||
|
|
@ -11,7 +11,7 @@ import base64, ctypes, ctypes.util, struct, functools, inspect, itertools
|
||||||
def g(s:str): return getattr(mesa, s)
|
def g(s:str): return getattr(mesa, s)
|
||||||
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
|
def nsrc(d:mesa.nir_def) -> mesa.nir_src: return mesa.nir_src(ssa=ctypes.pointer(d))
|
||||||
|
|
||||||
def glsl_type(t:DType): return mesa.glsl_array_type(glsl_type(t.base), t.size, 0).contents if isinstance(t, PtrDType) else {
|
def glsl_type(t:DType): return {
|
||||||
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
|
**{getattr(dtypes,k):g(f"glsl_type_builtin_{v}") for k,v in [('double','double'),('float','float'),('float16','float16_t'),('bool','uint8_t')]},
|
||||||
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
|
**{d:g(f"glsl_type_builtin_{'u' * (d in dtypes.uints)}int{str(d.bitsize)+'_t' if d.itemsize != 4 else ''}") for d in dtypes.ints}}[t]
|
||||||
|
|
||||||
|
|
@ -25,7 +25,6 @@ aop = {**{x:u_aop for x in (dtypes.bool,)+dtypes.uints}, **{x:s_aop for x in dty
|
||||||
|
|
||||||
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
|
def c(t:DType, u:bool=True) -> str: return "u" if t in dtypes.uints and u else ("i" if t in dtypes.ints else ("f" if t in dtypes.floats else "b"))
|
||||||
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
|
def ncast(b:mesa.nir_builder, src:mesa.nir_def, it:DType, ot:DType) -> mesa.nir_def:
|
||||||
if isinstance(it, PtrDType) and ot == dtypes.long: return src
|
|
||||||
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
|
return nalu(b, f"{c(it)}2{c(it) if it in dtypes.ints and ot in dtypes.ints else c(ot, ot == dtypes.bool)}{ot.bitsize}", src)
|
||||||
|
|
||||||
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
|
def nif(b:mesa.nir_builder, cond:mesa.nir_def, then_fn:Callable, else_fn:Callable):
|
||||||
|
|
@ -85,10 +84,10 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
||||||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1,
|
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1,
|
||||||
**({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})},
|
**({"ALIGN_MUL":val.bit_size//8*val.num_components} if space != AddrSpace.REG else {})},
|
||||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
lambda b, space, addr, val: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||||
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize//u.max_numel(), num_components=lambda u:u.max_numel(),
|
nload = nir_instr(nc=lambda u:u.max_numel(), bs=lambda u:u.dtype.bitsize, num_components=lambda u:u.max_numel(),
|
||||||
intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}),
|
intrins=lambda space,u:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}),
|
||||||
**({"ALIGN_MUL":u.dtype.itemsize} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
|
**({"ALIGN_MUL":u.dtype.itemsize*u.max_numel()} if space != AddrSpace.REG else {})}, srcs=lambda addr: [nsrc(addr)])(
|
||||||
lambda b, space, addr, u: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
lambda b, space, addr, u: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||||
|
|
||||||
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
|
ngid = nir_instr(nc=3, bs=32)(lambda b: mesa.nir_intrinsic_instr_create(b.shader, mesa.nir_intrinsic_load_workgroup_id))
|
||||||
|
|
@ -105,17 +104,18 @@ def njump(b:mesa.nir_builder, typ, tgt=None, cond=None, else_tgt=None): return m
|
||||||
|
|
||||||
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
|
def if_phi(b:mesa.nir_builder, cond, then_fn, else_fn): return mesa.nir_if_phi(b, *nif(b, cond, then_fn, else_fn)).contents
|
||||||
|
|
||||||
def nidx(b:mesa.nir_builder, buf, off, dtype, gate=None) -> mesa.nir_def:
|
def nidx(b:mesa.nir_builder, buf, off, space, itemsize, gate=None) -> mesa.nir_def:
|
||||||
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
|
@nir_instr(nc=1, bs=32, modes=lambda buf: buf.data.mode, type=lambda buf: mesa.glsl_get_array_element(buf.type))
|
||||||
def reg(b, buf):
|
def reg(b, buf):
|
||||||
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
|
deref = mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_array)
|
||||||
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
|
deref.contents.parent, deref.contents.arr.index = nsrc(deref_var(b, buf)), nsrc(off)
|
||||||
return deref
|
return deref
|
||||||
f = (functools.partial(reg, b, buf) if dtype.addrspace == AddrSpace.REG else
|
f = (functools.partial(reg, b, buf) if space == AddrSpace.REG else
|
||||||
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, dtype.itemsize, dtypes.long))))
|
lambda: nalu(b, "iadd", buf, nalu(b, "imul", off, nimm(b, itemsize, dtypes.long))))
|
||||||
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
return if_phi(b, gate, f, lambda: buf) if gate is not None else f()
|
||||||
|
|
||||||
class NIRRenderer(Renderer):
|
class NIRRenderer(Renderer):
|
||||||
|
new_style = True
|
||||||
suffix = "NIR"
|
suffix = "NIR"
|
||||||
nir_options: bytes
|
nir_options: bytes
|
||||||
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
|
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
|
||||||
|
|
@ -146,22 +146,23 @@ class NIRRenderer(Renderer):
|
||||||
|
|
||||||
def_rewrite = PatternMatcher([
|
def_rewrite = PatternMatcher([
|
||||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
|
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
|
||||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8)),
|
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
|
||||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 4)),
|
|
||||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"),UPat.var("off"))).or_casted(), UPat.var("val"))),
|
(UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
|
||||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), ctx.r[val], val.dtype)),
|
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
|
||||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(), UPat.var("alt"), UPat.var("gate")), name="x"),
|
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True), UPat.var("alt"),
|
||||||
|
UPat.var("gate")), name="x"),
|
||||||
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
lambda ctx,x,buf,off,alt,gate: if_phi(ctx.b, ctx.r[gate],
|
||||||
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype, ctx.r[gate]), x), lambda: ctx.r[alt])),
|
lambda: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize, ctx.r[gate]), x),
|
||||||
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off"))).or_casted(),), name="x"),
|
lambda: ctx.r[alt])),
|
||||||
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.dtype), x)),
|
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True),), name="x"),
|
||||||
|
lambda ctx,x,buf,off: nload(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), x)),
|
||||||
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])),
|
(UPat(Ops.STACK, name="x"), lambda ctx,x: nalu(ctx.b, f"vec{x.max_numel()}", *[ctx.r[src] for src in x.src])),
|
||||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype.scalar()][x.op], *[ctx.r[src] for src in x.src])),
|
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: nalu(ctx.b, aop[x.src[0].dtype][x.op], *[ctx.r[src] for src in x.src])),
|
||||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
|
(UPat(Ops.CAST, name="x"), lambda ctx,x: ncast(ctx.b, ctx.r[x.src[0]], x.src[0].dtype, x.dtype)),
|
||||||
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
|
(UPat(Ops.BITCAST, src=(UPat.var("a"),), allow_any_len=True), lambda ctx,a: ctx.r[a]),
|
||||||
(UPat(Ops.GEP, src=(UPat.var("a"),), name="x"), lambda ctx,x,a: nchannel(ctx.b, ctx.r[a], x.arg[0])),
|
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: mesa.nir_local_variable_create(ctx.b.impl,
|
||||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:mesa.nir_local_variable_create(ctx.b.impl, glsl_type(x.dtype), f"acc{x.arg}".encode()).contents),
|
mesa.glsl_array_type(glsl_type(x.dtype), x.max_numel(), 0).contents, f"acc{x.arg.slot}".encode()).contents),
|
||||||
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
|
(UPat(Ops.BARRIER), lambda ctx: nbarrier(ctx.b)),
|
||||||
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
|
(UPat(Ops.IF, name="x"), lambda ctx,x: mesa.nir_push_if(ctx.b, ctx.r[x.src[0]])),
|
||||||
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
|
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: (lambda _: mesa.nir_def())(mesa.nir_pop_if(ctx.b, ctx.r[x.src[0]])))
|
||||||
|
|
@ -190,19 +191,21 @@ class NIRRenderer(Renderer):
|
||||||
self.param_idx, ranges = 0, []
|
self.param_idx, ranges = 0, []
|
||||||
|
|
||||||
for u in uops:
|
for u in uops:
|
||||||
if u.op in {Ops.NOOP, Ops.GROUP, Ops.INDEX}: pass
|
if u.op in {Ops.NOOP, Ops.GROUP} or (u.op is Ops.STACK and len(u.src) == 0): pass
|
||||||
elif u.op is Ops.CAST and isinstance(u.dtype, PtrDType): pass
|
elif u.op in {Ops.INDEX, Ops.SHRINK}:
|
||||||
|
# INDEX on a register value picks the element, memory INDEX is handled in the LOAD/STORE patterns
|
||||||
|
if u.src[0].op not in {Ops.PARAM, Ops.BUFFER, Ops.AFTER}: self.r[u] = nchannel(self.b, self.r[u.src[0]], u.src[1].arg)
|
||||||
elif u.op is Ops.AFTER:
|
elif u.op is Ops.AFTER:
|
||||||
self.r[u] = self.r[u.src[0]]
|
self.r[u] = self.r[u.src[0]]
|
||||||
elif u.op == Ops.SINK:
|
elif u.op == Ops.SINK:
|
||||||
if u.arg is not None:
|
if u.arg is not None:
|
||||||
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
|
self.b.shader.contents.info.name = ctypes.cast(ctypes.create_string_buffer(u.arg.function_name.encode()), POINTER[ctypes.c_char])
|
||||||
elif u.op == Ops.DEFINE_LOCAL:
|
elif u.op == Ops.BUFFER and u.addrspace == AddrSpace.LOCAL:
|
||||||
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
|
self.r[u] = nimm(self.b, self.b.shader.contents.info.shared_size, dtypes.long)
|
||||||
self.b.shader.contents.info.shared_size += u.dtype.nbytes()
|
self.b.shader.contents.info.shared_size += u.max_numel()*u.dtype.itemsize
|
||||||
elif u.op == Ops.RANGE:
|
elif u.op == Ops.RANGE:
|
||||||
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
ranges.append(i:=deref_var(self.b, mesa.nir_local_variable_create(self.b.impl, glsl_type(u.dtype), f"idx{range_str(u)}".encode()).contents))
|
||||||
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype), u.dtype)
|
nstore(self.b, AddrSpace.REG, i, nimm(self.b, 0, u.dtype))
|
||||||
mesa.nir_push_loop(self.b)
|
mesa.nir_push_loop(self.b)
|
||||||
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
|
self.r[u] = nload(self.b, AddrSpace.REG, i, u)
|
||||||
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
nif(self.b, nalu(self.b, "ilt", self.r[u], self.r[u.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||||
|
|
@ -211,7 +214,7 @@ class NIRRenderer(Renderer):
|
||||||
next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
|
next_i = nalu(self.b, "iadd", self.r[r], nimm(self.b, 1, r.dtype))
|
||||||
# TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
|
# TODO: this nif should be removable ... but TestMultiTensor.test_double_matmul_shard_W_0 segfaults with it gone
|
||||||
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
nif(self.b, nalu(self.b, "ilt", next_i, self.r[r.src[0]]), lambda: None, lambda: njump(self.b, mesa.nir_jump_break))
|
||||||
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i, r.dtype),
|
nstore(self.b, AddrSpace.REG, ranges.pop(), next_i),
|
||||||
mesa.nir_pop_loop(self.b, None)
|
mesa.nir_pop_loop(self.b, None)
|
||||||
else:
|
else:
|
||||||
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
|
if (d:=self.def_rewrite.rewrite(u, ctx=self)) is None: raise RuntimeError(f"failed to render {u.op} srcs {[x.dtype for x in u.src]}")
|
||||||
|
|
@ -254,7 +257,7 @@ class LVPRenderer(NIRRenderer):
|
||||||
|
|
||||||
def prerender(self, uops:list[UOp]):
|
def prerender(self, uops:list[UOp]):
|
||||||
super().prerender(uops)
|
super().prerender(uops)
|
||||||
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
|
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
|
||||||
|
|
||||||
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
|
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||||
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
|
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
|
||||||
|
|
@ -269,7 +272,6 @@ _nload_img = nir_instr(intrins=lambda dtype:{'IMAGE_DIM':mesa.GLSL_SAMPLER_DIM_2
|
||||||
lambda b,img,idx_y,idx_x,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
lambda b,img,idx_y,idx_x,dtype: mesa.nir_intrinsic_instr_create(b.shader, g("nir_intrinsic_image_load")))
|
||||||
|
|
||||||
class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
||||||
new_style = False
|
|
||||||
has_aux = True
|
has_aux = True
|
||||||
|
|
||||||
def nload_img(ctx,img,idx_y,idx_x):
|
def nload_img(ctx,img,idx_y,idx_x):
|
||||||
|
|
@ -294,11 +296,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
||||||
def prerender(self, uops:list[UOp]):
|
def prerender(self, uops:list[UOp]):
|
||||||
super().prerender(uops)
|
super().prerender(uops)
|
||||||
self.texs:set[UOp] = set()
|
self.texs:set[UOp] = set()
|
||||||
self.uops, self.ibo_idx, self.img_idx = uops, 0, 0
|
self.img_idx = 0
|
||||||
self.param_sz = sum([8 if u.op == Ops.PARAM else u.dtype.itemsize for u in uops if u.op in (Ops.PARAM, Ops.DEFINE_VAR)])
|
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
|
||||||
|
|
||||||
def postrender(self, uops:list[UOp]):
|
def postrender(self, uops:list[UOp]):
|
||||||
bufs, texs, imgs = [u for u in uops if u.op == Ops.PARAM], itertools.count().__next__, itertools.count().__next__
|
bufs, texs, imgs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not None], itertools.count().__next__, itertools.count().__next__
|
||||||
for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int)
|
for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int)
|
||||||
|
|
||||||
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])
|
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue