remove DEFINE_VAR try 2 (#16651)

* remove DEFINE_VAR try 2

* param

* null index

* fix fuzzing

* fixes

* no gather neg params

* param is just Irreducible

* fixes

* skip stack

* need to filter slots there
This commit is contained in:
George Hotz 2026-06-18 12:34:25 -07:00 committed by GitHub
commit 5989d0b150
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 36 additions and 31 deletions

View file

@ -13,7 +13,7 @@ if __name__ == "__main__":
print(f"Progress: {i}")
dt = random.choice(dtypes.ints + tuple(dt.vec(4) for dt in dtypes.ints))
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
d = random.randint(1, max(1, u.arg[2])*2)
d = random.randint(1, max(1, u.vmax)*2)
if d in powers_of_two: continue
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
if expr is None: continue

View file

@ -20,7 +20,7 @@ binary_ops = [lambda a,b: a+b, lambda a,b: a*b, lambda a,b:a.maximum(b), lambda
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]
def random_or_sub_expression_int(depth, expr):
sub_expr = random.choice([e for e in expr.toposort() if e.dtype is not dtypes.bool])
sub_expr = random.choice([e for e in expr.toposort() if e.dtype not in (dtypes.bool, dtypes.void)])
return random.choice([random_int_expr(depth-1), sub_expr])
def random_int_expr(depth=10):

View file

@ -171,7 +171,7 @@ class TestGraphRewrite(unittest.TestCase):
c2 = UOp.const(dtypes.float, 2.0)
nout = graph_rewrite(v+c1+c2, simple_pm)
self.assertEqual(nout.op, Ops.ADD)
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
self.assertEqual(nout.src[0].op, Ops.PARAM)
self.assertEqual(nout.src[1].op, Ops.CONST)
self.assertEqual(nout.src[1].arg, 3.0)

View file

@ -9,7 +9,7 @@ from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_w
from tinygrad.uop.validate import uops_to_z3
def check_uop_against_string(self, v:UOp, s:str):
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL, Ops.PARAM)}
s_eval = eval(s, sym_vars)
if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)

View file

@ -226,7 +226,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(lst), 1)
graphs = [x["graph"] for x in viz.get_details(0, 0)]
# const is always in the graph, client side hides exclude=True nodes by default
self.assertEqual(list(graphs[0]), [id(a), id(z), id(alu), id(y), id(sink)])
self.assertEqual(list(graphs[0]), [id(a.src[0]), id(a), id(z), id(alu), id(y), id(sink)])
self.assertTrue(graphs[0][id(z)]["exclude"])
self.assertTrue(graphs[0][id(y)]["exclude"])
self.assertFalse(graphs[0][id(alu)]["exclude"])
@ -267,7 +267,7 @@ class TestViz(unittest.TestCase):
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
leaf = TrackedPatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite)])
leaf = TrackedPatternMatcher([(UPat(Ops.PARAM, name="x"), leaf_rewrite)])
def branch_rewrite(x:UOp, y:UOp):
if x.tag is not None: return

View file

@ -183,7 +183,7 @@ def finalize_after(ctx:AllocCtx, x:UOp):
def replace_input_buffer(ctx:AllocCtx, b:UOp):
ctx.replacements.append(b)
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device,
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None,
b._min_max if b.op is Ops.BIND else None, b.src[0].expr if b.op is Ops.BIND else None,
b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL)
pm_finalize_call = PatternMatcher([
@ -197,7 +197,7 @@ pm_replace_buf = PatternMatcher([
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
# strip value from BIND for cache key normalization, so different values hit same cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
(UPat(Ops.BIND, src=(UPat(Ops.PARAM), UPat(Ops.CONST)), name="b"), replace_input_buffer),
])
@track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}")

View file

@ -328,7 +328,7 @@ class Scheduler:
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg.slot)
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM and x.arg.slot >= 0], key=lambda x: x.arg.slot)
return [Buffer(dname, x.max_numel(), x.dtype.base) for x in glbls]
def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:

View file

@ -113,7 +113,7 @@ pm_reduce_collapse = pm_reduce_unparented + PatternMatcher([
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
# AND on WHERE
((UPat(Ops.DEFINE_VAR, name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
((UPat(Ops.PARAM, name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
# MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),

View file

@ -127,7 +127,7 @@ mop_cleanup = PatternMatcher([
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p) if p.arg.slot >= 0 else None), ])
def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
if c.arg.precompile: return None
params: list[UOp] = []
@ -533,7 +533,9 @@ to_define_global = PatternMatcher([
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
# this was DEFINE_VAR, clean this up and make it universal
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="v"),)), lambda v: v if v.addrspace == AddrSpace.ALU else None),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat(Ops.AFTER, name="after"), handle_after),

View file

@ -127,7 +127,7 @@ class GroupOp:
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
# BinaryOps that can be flipped

View file

@ -903,21 +903,20 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop Variable stuff ***
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.weakint) -> UOp:
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
def variable(name:str, min_val:PyConst, max_val:PyConst, dtype:DType=dtypes.weakint) -> UOp:
return UOp(Ops.PARAM, dtype, src=(shape_to_shape_arg((dtype.count,) if dtype.count > 1 else ()),),
arg=ParamArg(-1, name=name, vmin_vmax=(min_val, max_val), addrspace=AddrSpace.ALU))
@property
def expr(self) -> str:
if self.op is Ops.PARAM: return unwrap(self.arg.name)
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
assert self.op is Ops.PARAM
return unwrap(self.arg.name)
def bind(self, val:int|UOp):
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
assert self.op is Ops.PARAM and self.addrspace is AddrSpace.ALU, f"op is {self.op}, need PARAM"
uval = self.const_like(val) if isinstance(val, int) else val
assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
assert self.vmin <= uval.vmin and uval.vmax <= self.vmax, f"bind {val} not in range [{self.vmin}, {self.vmax}]"
return UOp(Ops.BIND, self.dtype, (self, uval))
def unbind(self) -> tuple[Variable, int]:
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
assert self.op is Ops.BIND and self.src[0].op is Ops.PARAM and self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
ret:dict[Variable, int] = {}
@ -1086,7 +1085,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def param_like(self, slot:int):
addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL
if self.op is Ops.BIND:
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].arg[0], addrspace)
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].expr, addrspace)
return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis)
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
@ -1675,7 +1674,8 @@ pm_lower_index_dtype = PatternMatcher([
# special can only be int32
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"),
lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)),
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
(UPat(Ops.PARAM, dtype=dtypes.weakint, name="u"),
lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint) if u.addrspace == AddrSpace.ALU else None),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts

View file

@ -24,7 +24,8 @@ def validate_index(uidx:UOp, gate:UOp|None=None):
# TODO: validate these
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
for x in idx.toposort() | gate.toposort():
# don't descend into PARAM shape metadata; only the PARAM value participates in index arithmetic
for x in idx.toposort(gate=lambda x: x.op is not Ops.PARAM) | gate.toposort(gate=lambda x: x.op is not Ops.PARAM):
if x.op in {Ops.BITCAST, Ops.STACK, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
# if all is good and CHECK_OOB=1, validate with z3
@ -136,7 +137,7 @@ spec_tensor = PatternMatcher([
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.PARAM), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
# custom function
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),

View file

@ -2,7 +2,7 @@
import math, struct
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.dtype import PyConst, ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
from tinygrad.uop.decompositions import threefry2x32, xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic
@ -259,7 +259,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU/variable min==max -> CONST
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
@ -332,7 +332,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
# return simplified uop (might be the same as input)
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
bounds:defaultdict[UOp, list[PyConst|None]] = defaultdict(lambda: [None, None])
for stmt in valid.split_uop(Ops.AND):
if (res:=parse_valid(stmt)) is None: continue
expr, is_upper, c = res

View file

@ -28,7 +28,7 @@ z3_renderer = PatternMatcher([
(UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
# variables
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
(UPat(Ops.PARAM, name="x"), lambda x,ctx: create_bounded(x.arg.name, x.vmin, x.vmax, ctx[0])),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx:
@ -52,10 +52,12 @@ z3_renderer = PatternMatcher([
def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
# gate on upstream AFTER/BUFFER as a replacement for PtrDType, but keep INDEX as an unknown LOAD
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM} and \
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER} and \
(x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1]
z3map: dict[UOp, z3.ExprRef] = {}
for u in lst:
# NOTE: we skip STACK here, it can't actually be accessed
if u.op is Ops.STACK: continue
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map))
if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3")
new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten)