mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
bitsize in nir
This commit is contained in:
parent
df000116ea
commit
431accc9b7
2 changed files with 19 additions and 23 deletions
|
|
@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType):
|
|||
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
|
||||
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, truncate[dtype](x))
|
||||
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
|
||||
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
|
||||
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents._def, x, dtype)
|
||||
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize*dtype.count)).contents._def, x, dtype)
|
||||
return instr
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
|
||||
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize)
|
||||
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
|
||||
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize*dtype.count)
|
||||
|
||||
deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108
|
||||
lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var))
|
||||
|
|
@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
|
|||
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
|
||||
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
|
||||
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
|
||||
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize, num_components=lambda dtype:dtype.count,
|
||||
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
|
||||
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))
|
||||
|
||||
|
|
|
|||
|
|
@ -27,12 +27,8 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
|
|||
val = (load.cast(dtypes.uint32) >> shift_am) & mask
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
def is_packed(dt:DType, odt:DType|None = None) -> bool:
|
||||
if odt is None: odt = dt
|
||||
# registers aren't packed
|
||||
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
|
||||
return dt.itemsize < 4 and dt.base != dtypes.half
|
||||
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
|
||||
def is_packed(u:UOp) -> bool: return u.dtype.itemsize < 4 and u.dtype.base != dtypes.half and u.addrspace != AddrSpace.REG
|
||||
def _packed_size(u:UOp): return u.max_numel() // (4//u.dtype.itemsize) if is_packed(u) else u.max_numel()
|
||||
|
||||
def is_nan(a):
|
||||
bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype)
|
||||
|
|
@ -43,12 +39,12 @@ wgsl_matcher = PatternMatcher([
|
|||
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
|
||||
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
|
||||
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
|
||||
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(b) else None),
|
||||
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(b) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
|
||||
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(bidx) else None),
|
||||
(UPat.store(UPat.var("bidx"), UPat.var("var")),
|
||||
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
|
||||
lambda bidx,var: packed_store(bidx,var) if is_packed(bidx) else None),
|
||||
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
|
||||
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
|
|
@ -73,8 +69,8 @@ class WGSLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
|
||||
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
|
||||
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
|
||||
|
|
@ -86,11 +82,11 @@ class WGSLRenderer(CStyleLanguage):
|
|||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
# TODO: load alt value doesnt have to be a const
|
||||
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
|
||||
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
|
||||
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
|
||||
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0])}, {ctx[gate]})"),
|
||||
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b)),
|
||||
(UPat.store(UPat.var("b"), UPat.var("v")), lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0]) \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
|
||||
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
|
||||
|
|
@ -98,8 +94,8 @@ class WGSLRenderer(CStyleLanguage):
|
|||
|
||||
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
|
||||
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
|
||||
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
|
||||
def render_load(self, x:str, u:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(u) else x
|
||||
def buf_map(self, u:UOp) -> str: return "atomic<u32>" if is_packed(u) else self.type_map[u.dtype.base]
|
||||
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
|
||||
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
|
||||
if not local_size: local_size = [1]
|
||||
|
|
@ -111,7 +107,7 @@ class WGSLRenderer(CStyleLanguage):
|
|||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
|
||||
f"{name}:{f'array<{self.buf_map(u)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u)};" for name,(u,_) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue