addrspace special/range (#16647)

* addrspace special/range

* just include indexing

* define var is alu

* bring old ignore indexing back

* mults to fix

* fixes

* ALU

* fixes
This commit is contained in:
George Hotz 2026-06-17 15:57:37 -07:00 committed by GitHub
commit aef85ddc4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 26 additions and 18 deletions

View file

@ -248,7 +248,7 @@ class TestTorchBackend(unittest.TestCase):
samples = torch.randint(0, X_train.shape[0], (32,))
X,Y = X_train[samples], Y_train[samples]
X.cpu(), Y.cpu()
self.assertLessEqual(GlobalCounters.global_ops, 20_000_000)
self.assertLessEqual(GlobalCounters.global_ops, 25_000_000)
def _test_diagonal(self, *shape):
a = torch.randn(*shape, dtype=torch.float32, device=device)

View file

@ -36,7 +36,7 @@ def custom_add_var(A:UOp, B:UOp) -> UOp:
A,B = A.flatten(), B.flatten()
assert A.dtype.base == dtypes.uint32, f"buffer dtype must be uint32, got {A.dtype}"
threads = UOp.special(A.numel(), "lidx0")
var = UOp.variable("var", 0, 10)
var = UOp.param(2, dtypes.weakint, vmin_vmax=(0, 10), name="var", addrspace=AddrSpace.ALU)
insts = [
s_load_b128(s[4:7], s[0:1]),
s_load_b32(s[8], s[0:1], offset=0x10), # all threads load the same variable

View file

@ -1,5 +1,6 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import all_same, Context
from tinygrad.uop.ops import GroupOp, UOp, Ops, exec_alu, PatternMatcher, TrackedPatternMatcher, UPat
from test.helpers import full_rewrite
@ -21,7 +22,7 @@ def apply_rewrite_values(expr):
def evaluate_uop(uop, variables):
if uop.op == Ops.CONST:
return uop.arg
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is None):
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is AddrSpace.ALU):
return variables[uop.expr]
elif uop.op in GroupOp.ALU:
src_values = [evaluate_uop(src, variables) for src in uop.src]

View file

@ -47,7 +47,7 @@ pm_remove_vec_dtypes = PatternMatcher([
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
# replace DEFINE_VAR with PARAM
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=AddrSpace.ALU))),
])+pm_clean_up_group_sink
def do_number_param(ctx:list[int], x:UOp):

View file

@ -24,6 +24,11 @@ class Estimates:
mem: dict[tuple[UOp, Ops], sint] = {}
mults: sint = 1
mult_stack: list[sint] = []
excluded: set[UOp] = set()
if ignore_indexing:
for u in uops:
if u.op in {Ops.INDEX, Ops.SHRINK}:
excluded = excluded.union(set(UOp.sink(*u.src[1:]).toposort()))
for u in uops:
if u.op in {Ops.LOAD, Ops.STORE}:
buf = u
@ -39,14 +44,14 @@ class Estimates:
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.END: mults = mult_stack.pop(-1)
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is Ops.PARAM and u.arg.addrspace is None and u.expr == 'core_id': mults *= int(u.vmax) + 1
elif u.op is Ops.PARAM and u.arg.addrspace == AddrSpace.ALU and u.expr == 'core_id': mults *= int(u.vmax) + 1
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.dtype.scalar().itemsize * mults
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
lds += u.max_numel() * u.src[1].dtype.scalar().itemsize * mults
elif u.op in GroupOp.ALU and (not ignore_indexing or u.addrspace is not None):
elif u.op in GroupOp.ALU and u not in excluded:
flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.max_numel()
elif u.op is Ops.WMMA and (not ignore_indexing or u.addrspace is not None):
elif u.op is Ops.WMMA and u not in excluded:
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return Estimates(flops, lds, sum(mem.values()))

View file

@ -145,7 +145,7 @@ class NIRRenderer(Renderer):
def_rewrite = PatternMatcher([
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, x.dtype.itemsize if x.addrspace is AddrSpace.ALU else 8)),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
(UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
@ -256,7 +256,7 @@ class LVPRenderer(NIRRenderer):
def prerender(self, uops:list[UOp]):
super().prerender(uops)
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
self.param_sz = sum([u.dtype.itemsize if u.addrspace is AddrSpace.ALU else 8 for u in uops if u.op is Ops.PARAM])
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
@ -296,10 +296,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
super().prerender(uops)
self.texs:set[UOp] = set()
self.img_idx = 0
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
self.param_sz = sum([u.dtype.itemsize if u.addrspace is AddrSpace.ALU else 8 for u in uops if u.op is Ops.PARAM])
def postrender(self, uops:list[UOp]):
bufs, texs, imgs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not None], itertools.count().__next__, itertools.count().__next__
bufs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not AddrSpace.ALU]
texs, imgs = itertools.count().__next__, itertools.count().__next__
for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int)
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])

View file

@ -85,7 +85,7 @@ class PythonProgram:
i += 1
continue
if u.op is Ops.AFTER: values[u] = src_values[0]
elif u.op is Ops.PARAM and u.addrspace is None: values[u] = [pvals.pop(0)] * warp_size
elif u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: values[u] = [pvals.pop(0)] * warp_size
elif u.op in {Ops.PARAM, Ops.BUFFER}:
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")

View file

@ -796,6 +796,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
if self.op in {Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
return self.src[0].addrspace
@ -907,7 +908,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
def expr(self) -> str:
if self.op is Ops.PARAM and self.arg.addrspace is None: return unwrap(self.arg.name)
if self.op is Ops.PARAM: return unwrap(self.arg.name)
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int|UOp):
@ -924,7 +925,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
@property
def val(self) -> int: return self.unbind()[1]
def variables(self) -> list[Variable]:
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)},
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.ALU)},
key=lambda v: v.expr)
# *** uop symbolic stuff ***
@ -1033,7 +1034,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def _sym_fxn(self):
from tinygrad.uop.render import _render_with_splits, renderer_infer
sself = self.simplify()
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)))
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace == AddrSpace.ALU)))
# TODO: sanitize varnames, or don't use naked eval while staying fast
ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself})
lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"]
@ -1149,8 +1150,8 @@ class ProgramInfo:
global_size: list[int] = [1, 1, 1]
local_size: list[int]|None = [1, 1, 1]
for u in sink.toposort():
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace is None): _vars.append(u)
if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot)
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace == AddrSpace.ALU): _vars.append(u)
if u.op is Ops.PARAM and u.addrspace != AddrSpace.ALU: _globals.append(u.arg.slot)
if u.op in (Ops.STORE, Ops.LOAD):
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
@ -1160,7 +1161,7 @@ class ProgramInfo:
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
if u.op in (Ops.DEFINE_VAR, Ops.PARAM) and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.expr)),
tuple(local_size) if local_size is not None else None, tuple(sorted(dedup(_vars), key=lambda v: v.arg.slot)),
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)
@dataclass(frozen=True)