Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
c7b6ee0c7d dt.count 2026-06-02 13:19:04 -07:00
George Hotz
13f5d39fcf count 2026-06-02 13:16:42 -07:00
George Hotz
431accc9b7 bitsize in nir 2026-06-02 13:11:31 -07:00
George Hotz
df000116ea more renderer cleanups 2026-06-02 13:02:28 -07:00
6 changed files with 34 additions and 36 deletions

View file

@ -56,7 +56,7 @@ class AddrSpace(Enum):
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):
priority: int # this determines when things get upcasted
bitsize: int
bitsize: int # this is the bitsize of the base dtype
name: str
fmt: FmtStr|None
count: int
@ -76,7 +76,7 @@ class DType(metaclass=DTypeMetaClass):
def vec(self, sz:int) -> DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.bitsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
return DType(self.priority, self.bitsize, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self

View file

@ -3,7 +3,7 @@ from typing import Callable, cast
from dataclasses import dataclass
from tinygrad.helpers import prod, Target, EMULATED_DTYPES
from tinygrad.uop.ops import Ops, UOp, sint, ssimplify, smin, GroupOp, PatternMatcher
from tinygrad.dtype import AddrSpace, PtrDType, DType, dtypes
from tinygrad.dtype import AddrSpace, DType, dtypes
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.device import Compiler
@ -41,7 +41,7 @@ class Estimates:
while len(buf.src) and buf.op is not Ops.PARAM: buf = buf.src[0]
if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
accessed = mem.get((buf, u.op), 0) + u.max_numel() * u.src[0].dtype.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.max_numel() * buf.dtype.itemsize)
if u.op is Ops.RANGE:
mult_stack.append(mults)
@ -51,10 +51,10 @@ class Estimates:
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.dtype.itemsize * mults
elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
lds += u.src[1].dtype.itemsize * mults
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.dtype.itemsize * mults
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.src[1].dtype.itemsize * mults
elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))

View file

@ -244,7 +244,7 @@ class ClangRenderer(CStyleLanguage):
kernel_typedef = "__attribute__((ms_abi)) void"
def render_vector_prefix(self, dt:DType) -> str:
# round (down) to power of two (this is actually the default clang behavior)
alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
alignment = 2**int(math.log2(dt.itemsize*dt.count)) if getenv("ALIGNED", 1) and not dtypes.is_bool(dt) else 1
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),ext_vector_type({dt.count})));"
def _render_defines(self, uops) -> list[str]:
@ -433,7 +433,8 @@ class CUDARenderer(CStyleLanguage):
def render_vector_prefix(self, dt:DType) -> str:
vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
return f"struct __align__({dt.itemsize*dt.count}) {vec} {{ {scal} {elems}; }};" + \
f"__device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
@ -450,7 +451,8 @@ class CUDARenderer(CStyleLanguage):
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
# 4 => CUDA reg size in bytes
n_operands = [size*dtype.itemsize*dtype.count//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
operands = [f"%{i}" for i in range(sum(n_operands))]
# mma operands => {c}, {a}, {b}, {c}

View file

@ -633,8 +633,8 @@ def encode(x:UOp, opc:int, reg:int|None=None, pp:int=0, sel:int=0, we:int=0) ->
reg = cast(int, cast(Register, reg_uop.reg).index if reg_uop is not None else reg)
rm = cast(Register, rm_uop.reg).index
idx = cast(Register, idx_uop.reg).index if idx_uop is not None and idx_uop.reg is not None else 4
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else rm_uop.dtype.itemsize
reg_sz = (reg_uop.dtype.itemsize if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
rm_sz = 8 if isinstance(rm_uop.dtype, PtrDType) and disp_uop is None else (rm_uop.dtype.itemsize*rm_uop.dtype.count)
reg_sz = ((reg_uop.dtype.itemsize*reg_uop.dtype.count) if not isinstance(reg_uop.dtype, PtrDType) else 8) if reg_uop is not None else 0
sz = reg_sz or rm_sz
# encode instruction

View file

@ -71,12 +71,12 @@ def nimm_set(imm:mesa.nir_def, x, dtype:DType):
instr = ctypes.cast(imm.parent_instr, ctypes.POINTER(mesa.nir_load_const_instr))
struct.pack_into(unwrap(dtype.fmt), (ctypes.c_ubyte * dtype.itemsize).from_address(ctypes.addressof(instr.contents.value)), 0, truncate[dtype](x))
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
def nimm(b:mesa.nir_builder, x, dtype:DType) -> mesa.nir_def:
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize)).contents._def, x, dtype)
nimm_set((instr:=mesa.nir_load_const_instr_create(b.shader, 1, dtype.bitsize*dtype.count)).contents._def, x, dtype)
return instr
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize)
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize)
@nir_instr(nc=1, bs=lambda dtype: dtype.bitsize*dtype.count)
def nundef(b, dtype): return mesa.nir_undef_instr_create(b.shader, 1, dtype.bitsize*dtype.count)
deref_var = nir_instr(nc=1, bs=32, modes=lambda var:var.data.mode, type=lambda var:var.type, var=lambda var:ctypes.pointer(var))( # pylint: disable=W0108
lambda b, var: mesa.nir_deref_instr_create(b.shader, mesa.nir_deref_type_var))
@ -86,7 +86,7 @@ def scope(space): return 'global' if space == AddrSpace.GLOBAL else ('shared' if
nstore = nir_instr(has_def=False, df=lambda addr:addr, intrins=lambda space,val: {"WRITE_MASK":(1<<val.num_components)-1, **iointr(space)},
num_components=lambda val:val.num_components, srcs=lambda space, addr, val: [nsrc(val), nsrc(addr)][::1 if space != AddrSpace.REG else -1])(
lambda b, space, addr, val, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_store_{scope(space)}")))
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize//dtype.count, num_components=lambda dtype:dtype.count,
nload = nir_instr(nc=lambda dtype:dtype.count, bs=lambda dtype:dtype.bitsize, num_components=lambda dtype:dtype.count,
intrins=lambda space:{**({"ACCESS":mesa.ACCESS_CAN_REORDER} if space==AddrSpace.GLOBAL else {}), **iointr(space)}, srcs=lambda addr: [nsrc(addr)])(
lambda b, space, addr, dtype: mesa.nir_intrinsic_instr_create(b.shader, g(f"nir_intrinsic_load_{scope(space)}")))

View file

@ -27,12 +27,8 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None, gate:UOp|Non
val = (load.cast(dtypes.uint32) >> shift_am) & mask
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
def is_packed(dt:DType, odt:DType|None = None) -> bool:
if odt is None: odt = dt
# registers aren't packed
if isinstance(odt, PtrDType) and odt.addrspace == AddrSpace.REG: return False
return dt.itemsize < 4 and dt.base != dtypes.half
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
def is_packed(u:UOp) -> bool: return u.dtype.itemsize < 4 and u.dtype.base != dtypes.half and u.addrspace != AddrSpace.REG
def _packed_size(u:UOp): return u.max_numel() // (4//u.dtype.itemsize) if is_packed(u) else u.max_numel()
def is_nan(a):
bs, (exp, mant) = a.dtype.bitsize, dtypes.finfo(a.dtype)
@ -43,12 +39,12 @@ wgsl_matcher = PatternMatcher([
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("c"), UPat.var("gate"), name="l"),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(l.dtype, b.dtype) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
lambda l,b,c,gate: packed_load(l,b,l.dtype,c.cast(dtypes.uint32),gate) if is_packed(b) else None),
(UPat.load(UPat.var("b"), name='l'), lambda l,b: packed_load(l, b, l.dtype) if is_packed(b) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var"), UPat.var("gate")),
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(var.dtype, bidx.dtype) else None),
lambda bidx,var,gate: packed_store(bidx,var,gate) if is_packed(bidx) else None),
(UPat.store(UPat.var("bidx"), UPat.var("var")),
lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
lambda bidx,var: packed_store(bidx,var) if is_packed(bidx) else None),
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
# fix nan check: 'a != a -> is_nan()'
@ -73,8 +69,8 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x)},{_packed_size(x)}>;"),
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
(UPat(Ops.BITCAST, dtype=dtypes.uchar, name="x"), lambda ctx,x: f"bitcast<u32>({ctx[x.src[0]]}&0xFF)"),
@ -86,11 +82,11 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
# TODO: load alt value doesnt have to be a const
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("gate")),
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[gate]})"),
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
lambda ctx,b,v,gate: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0])}, {ctx[gate]})"),
(UPat.load(UPat.var("b")), lambda ctx, b: ctx.render_load(ctx[b], b)),
(UPat.store(UPat.var("b"), UPat.var("v")), lambda ctx,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0]) \
else f"{ctx[b]} = {ctx[v]};"),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
@ -98,8 +94,8 @@ class WGSLRenderer(CStyleLanguage):
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
def render_load(self, x:str, u:UOp) -> str: return f"atomicLoad(&{x})" if is_packed(u) else x
def buf_map(self, u:UOp) -> str: return "atomic<u32>" if is_packed(u) else self.type_map[u.dtype.base]
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[UOp,bool]]], uops:list[UOp], prefix=None) -> str:
local_size = [u.src[0].ssimplify() for u in sorted([u for u in uops if u.op is Ops.SPECIAL and u.arg[0] == 'l'], key=lambda u: u.arg)]
if not local_size: local_size = [1]
@ -111,7 +107,7 @@ class WGSLRenderer(CStyleLanguage):
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if isinstance(u.dtype, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.buf_map(u.dtype.base)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u.dtype)};" for name,(u,_) in bufs])
f"{name}:{f'array<{self.buf_map(u)}>' if isinstance(u.dtype,PtrDType) else self.buf_map(u)};" for name,(u,_) in bufs])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"