mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
addrspace special/range (#16647)
* addrspace special/range
* just include indexing
* define var is alu
* bring old ignore indexing back
* mults to fix
* fixes
* ALU
* fixes
This commit is contained in:
parent
1e08c0a07c
commit
aef85ddc4d
8 changed files with 26 additions and 18 deletions
|
|
@ -248,7 +248,7 @@ class TestTorchBackend(unittest.TestCase):
|
|||
samples = torch.randint(0, X_train.shape[0], (32,))
|
||||
X,Y = X_train[samples], Y_train[samples]
|
||||
X.cpu(), Y.cpu()
|
||||
self.assertLessEqual(GlobalCounters.global_ops, 20_000_000)
|
||||
self.assertLessEqual(GlobalCounters.global_ops, 25_000_000)
|
||||
|
||||
def _test_diagonal(self, *shape):
|
||||
a = torch.randn(*shape, dtype=torch.float32, device=device)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def custom_add_var(A:UOp, B:UOp) -> UOp:
|
|||
A,B = A.flatten(), B.flatten()
|
||||
assert A.dtype.base == dtypes.uint32, f"buffer dtype must be uint32, got {A.dtype}"
|
||||
threads = UOp.special(A.numel(), "lidx0")
|
||||
var = UOp.variable("var", 0, 10)
|
||||
var = UOp.param(2, dtypes.weakint, vmin_vmax=(0, 10), name="var", addrspace=AddrSpace.ALU)
|
||||
insts = [
|
||||
s_load_b128(s[4:7], s[0:1]),
|
||||
s_load_b32(s[8], s[0:1], offset=0x10), # all threads load the same variable
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import unittest, math
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import all_same, Context
|
||||
from tinygrad.uop.ops import GroupOp, UOp, Ops, exec_alu, PatternMatcher, TrackedPatternMatcher, UPat
|
||||
from test.helpers import full_rewrite
|
||||
|
|
@ -21,7 +22,7 @@ def apply_rewrite_values(expr):
|
|||
def evaluate_uop(uop, variables):
|
||||
if uop.op == Ops.CONST:
|
||||
return uop.arg
|
||||
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is None):
|
||||
elif uop.op == Ops.DEFINE_VAR or (uop.op == Ops.PARAM and uop.arg.addrspace is AddrSpace.ALU):
|
||||
return variables[uop.expr]
|
||||
elif uop.op in GroupOp.ALU:
|
||||
src_values = [evaluate_uop(src, variables) for src in uop.src]
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ pm_remove_vec_dtypes = PatternMatcher([
|
|||
x.replace(op=Ops.BUFFER, arg=ParamArg(x.arg, addrspace=AddrSpace.LOCAL if x.op == Ops.DEFINE_LOCAL else AddrSpace.REG))),
|
||||
# replace DEFINE_VAR with PARAM
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x:
|
||||
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=None))),
|
||||
x.replace(op=Ops.PARAM, src=(UOp(Ops.STACK),), arg=ParamArg(slot=-1, name=x.arg[0], vmin_vmax=x.arg[1:], addrspace=AddrSpace.ALU))),
|
||||
])+pm_clean_up_group_sink
|
||||
|
||||
def do_number_param(ctx:list[int], x:UOp):
|
||||
|
|
|
|||
|
|
@ -24,6 +24,11 @@ class Estimates:
|
|||
mem: dict[tuple[UOp, Ops], sint] = {}
|
||||
mults: sint = 1
|
||||
mult_stack: list[sint] = []
|
||||
excluded: set[UOp] = set()
|
||||
if ignore_indexing:
|
||||
for u in uops:
|
||||
if u.op in {Ops.INDEX, Ops.SHRINK}:
|
||||
excluded = excluded.union(set(UOp.sink(*u.src[1:]).toposort()))
|
||||
for u in uops:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
buf = u
|
||||
|
|
@ -39,14 +44,14 @@ class Estimates:
|
|||
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
|
||||
elif u.op is Ops.END: mults = mult_stack.pop(-1)
|
||||
elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is Ops.PARAM and u.arg.addrspace is None and u.expr == 'core_id': mults *= int(u.vmax) + 1
|
||||
elif u.op is Ops.PARAM and u.arg.addrspace == AddrSpace.ALU and u.expr == 'core_id': mults *= int(u.vmax) + 1
|
||||
elif u.op is Ops.LOAD and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.dtype.scalar().itemsize * mults
|
||||
elif u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG:
|
||||
lds += u.max_numel() * u.src[1].dtype.scalar().itemsize * mults
|
||||
elif u.op in GroupOp.ALU and (not ignore_indexing or u.addrspace is not None):
|
||||
elif u.op in GroupOp.ALU and u not in excluded:
|
||||
flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.max_numel()
|
||||
elif u.op is Ops.WMMA and (not ignore_indexing or u.addrspace is not None):
|
||||
elif u.op is Ops.WMMA and u not in excluded:
|
||||
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return Estimates(flops, lds, sum(mem.values()))
|
||||
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class NIRRenderer(Renderer):
|
|||
|
||||
def_rewrite = PatternMatcher([
|
||||
(UPat(Ops.CONST, name="x"), lambda ctx,x: nimm(ctx.b, x.arg, x.dtype)),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, 8 if x.addrspace is not None else x.dtype.itemsize)),
|
||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx.param(ctx.b, x, x.dtype.itemsize if x.addrspace is AddrSpace.ALU else 8)),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: nchannel(ctx.b, {'g':ngid, 'l':nlid, 'i': nid}[x.arg[0]](ctx.b), int(x.arg[-1]))),
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat.var("buf"),UPat.var("off")), allow_any_len=True), UPat.var("val"))),
|
||||
lambda ctx,buf,off,val: nstore(ctx.b, buf.addrspace, nidx(ctx.b, ctx.r[buf], ctx.r[off], buf.addrspace, buf.dtype.itemsize), ctx.r[val])),
|
||||
|
|
@ -256,7 +256,7 @@ class LVPRenderer(NIRRenderer):
|
|||
|
||||
def prerender(self, uops:list[UOp]):
|
||||
super().prerender(uops)
|
||||
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
|
||||
self.param_sz = sum([u.dtype.itemsize if u.addrspace is AddrSpace.ALU else 8 for u in uops if u.op is Ops.PARAM])
|
||||
|
||||
def tovec(b, idx_y, idx_x): return nalu(b, "vec4", idx_x, idx_y, nundef(b, dtypes.int), nundef(b, dtypes.int))
|
||||
def nfloat(dtype): return mesa.nir_type_float16 if dtype == dtypes.half else mesa.nir_type_float32
|
||||
|
|
@ -296,10 +296,11 @@ class IR3Renderer(NIRRenderer, OpenCLRenderer):
|
|||
super().prerender(uops)
|
||||
self.texs:set[UOp] = set()
|
||||
self.img_idx = 0
|
||||
self.param_sz = sum([8 if u.addrspace is not None else u.dtype.itemsize for u in uops if u.op is Ops.PARAM])
|
||||
self.param_sz = sum([u.dtype.itemsize if u.addrspace is AddrSpace.ALU else 8 for u in uops if u.op is Ops.PARAM])
|
||||
|
||||
def postrender(self, uops:list[UOp]):
|
||||
bufs, texs, imgs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not None], itertools.count().__next__, itertools.count().__next__
|
||||
bufs = [u for u in uops if u.op is Ops.PARAM and u.addrspace is not AddrSpace.ALU]
|
||||
texs, imgs = itertools.count().__next__, itertools.count().__next__
|
||||
for b in filter(lambda b: isinstance(b.dtype, ImageDType), bufs): nimm_set(self.r[b], texs() if b in self.texs else imgs(), dtypes.int)
|
||||
|
||||
self.b.shader.contents.info.num_ubos = len([u for u in bufs if not isinstance(u.dtype, ImageDType)])
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class PythonProgram:
|
|||
i += 1
|
||||
continue
|
||||
if u.op is Ops.AFTER: values[u] = src_values[0]
|
||||
elif u.op is Ops.PARAM and u.addrspace is None: values[u] = [pvals.pop(0)] * warp_size
|
||||
elif u.op is Ops.PARAM and u.addrspace is AddrSpace.ALU: values[u] = [pvals.pop(0)] * warp_size
|
||||
elif u.op in {Ops.PARAM, Ops.BUFFER}:
|
||||
storage_fmt = storage_fmt_for_dtype(u.dtype.base.scalar())
|
||||
if storage_fmt is None: raise RuntimeError(f"dtype={u.dtype} is not supported")
|
||||
|
|
|
|||
|
|
@ -796,6 +796,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.BUFFER: return self.arg.addrspace if isinstance(self.arg, ParamArg) else AddrSpace.GLOBAL
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
if self.op in {Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}: return AddrSpace.ALU
|
||||
if self.op is Ops.LOAD: return AddrSpace.ALU # LOAD brings things into the ALU
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
||||
return self.src[0].addrspace
|
||||
|
|
@ -907,7 +908,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
@property
|
||||
def expr(self) -> str:
|
||||
if self.op is Ops.PARAM and self.arg.addrspace is None: return unwrap(self.arg.name)
|
||||
if self.op is Ops.PARAM: return unwrap(self.arg.name)
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
return self.arg[0]
|
||||
def bind(self, val:int|UOp):
|
||||
|
|
@ -924,7 +925,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
@property
|
||||
def val(self) -> int: return self.unbind()[1]
|
||||
def variables(self) -> list[Variable]:
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)},
|
||||
return sorted({x for x in self.backward_slice_with_self if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is AddrSpace.ALU)},
|
||||
key=lambda v: v.expr)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
|
@ -1033,7 +1034,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
def _sym_fxn(self):
|
||||
from tinygrad.uop.render import _render_with_splits, renderer_infer
|
||||
sself = self.simplify()
|
||||
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace is None)))
|
||||
varnames = tuple(dedup(x.expr for x in sself.toposort() if x.op is Ops.DEFINE_VAR or (x.op is Ops.PARAM and x.arg.addrspace == AddrSpace.ALU)))
|
||||
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
||||
ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself})
|
||||
lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"]
|
||||
|
|
@ -1149,8 +1150,8 @@ class ProgramInfo:
|
|||
global_size: list[int] = [1, 1, 1]
|
||||
local_size: list[int]|None = [1, 1, 1]
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace is None): _vars.append(u)
|
||||
if u.op is Ops.PARAM and u.addrspace is not None: _globals.append(u.arg.slot)
|
||||
if u.op is Ops.DEFINE_VAR or (u.op is Ops.PARAM and u.addrspace == AddrSpace.ALU): _vars.append(u)
|
||||
if u.op is Ops.PARAM and u.addrspace != AddrSpace.ALU: _globals.append(u.arg.slot)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
if (idx:=u.src[0]).op in (Ops.INDEX, Ops.SHRINK) or (u.src[0].op is Ops.CAST and (idx:=u.src[0].src[0]).op is Ops.INDEX):
|
||||
if (buf:=idx.src[0]).op is Ops.PARAM: (outs if u.op is Ops.STORE else ins).append(buf.arg.slot)
|
||||
|
|
@ -1160,7 +1161,7 @@ class ProgramInfo:
|
|||
if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
|
||||
if u.op in (Ops.DEFINE_VAR, Ops.PARAM) and u in _vars and u.expr == 'core_id': global_size[0] = int(u.vmax) + 1
|
||||
return ProgramInfo(sink.arg.name if isinstance(sink.arg, KernelInfo) else "test", tuple(global_size),
|
||||
tuple(local_size) if local_size is not None else None, tuple(sorted(_vars, key=lambda v: v.expr)),
|
||||
tuple(local_size) if local_size is not None else None, tuple(sorted(dedup(_vars), key=lambda v: v.arg.slot)),
|
||||
tuple(sorted(dedup(_globals))), tuple(sorted(dedup(outs))), tuple(sorted(dedup(ins))), aux)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue