Compare commits

...

8 commits

Author SHA1 Message Date
George Hotz
eee8d261e3
Merge branch 'master' into remove_define_var 2026-06-17 17:50:53 -07:00
George Hotz
f1a0643b15
Merge branch 'master' into remove_define_var 2026-06-17 16:40:59 -07:00
George Hotz
b8c70eb3e9
Merge branch 'master' into remove_define_var 2026-06-17 15:58:57 -07:00
George Hotz
0c3fdda48b work 2026-06-17 11:29:59 -07:00
George Hotz
28bd22db0b passing? 2026-06-17 11:03:30 -07:00
George Hotz
1c7cc0a379 do things pass 2026-06-17 09:30:37 -07:00
George Hotz
123585b93c passing 2026-06-17 09:07:19 -07:00
George Hotz
27fe5e23a8 remove DEFINE_VAR 2026-06-17 08:51:11 -07:00
13 changed files with 53 additions and 34 deletions

View file

@ -13,7 +13,7 @@ if __name__ == "__main__":
print(f"Progress: {i}")
dt = random.choice(dtypes.ints + tuple(dt.vec(4) for dt in dtypes.ints))
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
d = random.randint(1, max(1, u.arg[2])*2)
d = random.randint(1, max(1, u.vmax)*2)
if d in powers_of_two: continue
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
if expr is None: continue

View file

@ -20,7 +20,7 @@ binary_ops = [lambda a,b: a+b, lambda a,b: a*b, lambda a,b:a.maximum(b), lambda
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]
def random_or_sub_expression_int(depth, expr):
sub_expr = random.choice([e for e in expr.toposort() if e.dtype is not dtypes.bool])
sub_expr = random.choice([e for e in expr.toposort() if e.dtype not in (dtypes.bool, dtypes.void)])
return random.choice([random_int_expr(depth-1), sub_expr])
def random_int_expr(depth=10):

View file

@ -171,7 +171,7 @@ class TestGraphRewrite(unittest.TestCase):
c2 = UOp.const(dtypes.float, 2.0)
nout = graph_rewrite(v+c1+c2, simple_pm)
self.assertEqual(nout.op, Ops.ADD)
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
self.assertEqual(nout.src[0].op, Ops.PARAM)
self.assertEqual(nout.src[1].op, Ops.CONST)
self.assertEqual(nout.src[1].arg, 3.0)

View file

@ -9,7 +9,8 @@ from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_w
from tinygrad.uop.validate import uops_to_z3
def check_uop_against_string(self, v:UOp, s:str):
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}
sym_vars = {v.render():v for v in v.toposort()
if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL) or (v.op is Ops.PARAM and v.arg.vmin_vmax is not None)}
s_eval = eval(s, sym_vars)
if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)
@ -1296,8 +1297,8 @@ class TestMoveWhereOnLoad(unittest.TestCase):
# cond has a range that the rewrite can move into the valid: gate (a<4) goes into load valid
cond = (a < 4) & (r < 2)
valid = (a < 2) # pre-existing valid on the load (to pass can_move check for the r-only clause)
idx = buf.index(a.valid(valid), ptr=True)
expr = cond.where(idx, 0)
idx = buf.index(a.valid(valid))
expr = cond.where(idx, False)
out = graph_rewrite(expr, pm_move_where_on_load)
# any WHERE in the rewritten graph must have matched-dtype branches
for u in out.toposort():

View file

@ -224,7 +224,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(lst), 1)
graphs = [x["graph"] for x in viz.get_details(0, 0)]
# const is always in the graph, client side hides exclude=True nodes by default
self.assertEqual(list(graphs[0]), [id(a), id(z), id(alu), id(y), id(sink)])
self.assertEqual(list(graphs[0]), [id(a.src[0]), id(a), id(z), id(alu), id(y), id(sink)])
self.assertTrue(graphs[0][id(z)]["exclude"])
self.assertTrue(graphs[0][id(y)]["exclude"])
self.assertFalse(graphs[0][id(alu)]["exclude"])
@ -265,7 +265,7 @@ class TestViz(unittest.TestCase):
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
leaf = TrackedPatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite)])
leaf = TrackedPatternMatcher([(UPat(Ops.PARAM, name="x"), leaf_rewrite)])
def branch_rewrite(x:UOp, y:UOp):
if x.tag is not None: return

View file

@ -183,7 +183,7 @@ def finalize_after(ctx:AllocCtx, x:UOp):
def replace_input_buffer(ctx:AllocCtx, b:UOp):
ctx.replacements.append(b)
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device,
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None,
b._min_max if b.op is Ops.BIND else None, b.src[0].expr if b.op is Ops.BIND else None,
b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL)
pm_finalize_call = PatternMatcher([
@ -197,7 +197,7 @@ pm_replace_buf = PatternMatcher([
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
# strip value from BIND for cache key normalization, so different values hit same cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
(UPat(Ops.BIND, src=(UPat(Ops.PARAM), UPat(Ops.CONST)), name="b"), replace_input_buffer),
])
@track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}")

View file

@ -113,8 +113,8 @@ pm_reduce_collapse = pm_reduce_unparented + PatternMatcher([
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
# AND on WHERE
((UPat(Ops.DEFINE_VAR, name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
((UPat((Ops.DEFINE_VAR, Ops.PARAM), name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype) if x.op is Ops.DEFINE_VAR or x.arg.vmin_vmax is not None else None),
# MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
])+symbolic

View file

@ -1,7 +1,7 @@
import math, functools, operator
from typing import TYPE_CHECKING, Literal, Self
from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float, Invalid
from tinygrad.helpers import argfix, polyN
from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.creation import CreationMixin
@ -416,7 +416,10 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
def where(self, x: Self | ConstType, y: Self | ConstType) -> Self:
ref: Self = x if isinstance(x, type(self)) else y if isinstance(y, type(self)) else \
self.cast(least_upper_dtype(dtypes.from_py(x), dtypes.from_py(y)))
return self.alu(Ops.WHERE, ref.ufix(x), ref.ufix(y))
fx, fy = ref.ufix(x), ref.ufix(y)
if getattr(fx, "op", None) is Ops.CONST and getattr(fx, "arg", None) is Invalid and fx.dtype != fy.dtype: fx = fy.ufix(Invalid)
if getattr(fy, "op", None) is Ops.CONST and getattr(fy, "arg", None) is Invalid and fy.dtype != fx.dtype: fy = fx.ufix(Invalid)
return self.alu(Ops.WHERE, fx, fy)
def masked_fill(self, mask:Self, value:Self|PyConst) -> Self:
"""

View file

@ -127,7 +127,7 @@ mop_cleanup = PatternMatcher([
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p) if p.arg.slot >= 0 else None), ])
def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
if c.arg.precompile: return None
params: list[UOp] = []
@ -533,7 +533,8 @@ to_define_global = PatternMatcher([
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_VAR, Ops.PARAM), name="v"),)),
lambda v: v if v.op is Ops.DEFINE_VAR or v.arg.vmin_vmax is not None else None),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat(Ops.AFTER, name="after"), handle_after),

View file

@ -901,21 +901,22 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop Variable stuff ***
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.weakint) -> UOp:
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
def variable(name:str, min_val:PyConst, max_val:PyConst, dtype:DType=dtypes.weakint) -> UOp:
return UOp(Ops.PARAM, dtype, src=(shape_to_shape_arg((dtype.count,) if dtype.count > 1 else ()),),
arg=ParamArg(-1, name=name, vmin_vmax=(min_val, max_val), addrspace=None))
@property
def expr(self) -> str:
if self.op is Ops.PARAM: return unwrap(self.arg.name)
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int|UOp):
  """Bind this variable to `val`, returning a BIND UOp whose sources are `(self, value)`.

  `self` must be a PARAM with no address space that carries a `(vmin, vmax)` pair in `arg.vmin_vmax`.
  An `int` is converted with `const_like`; a UOp is used as-is. The value's own `[vmin, vmax]`
  must lie inside the variable's range, otherwise an AssertionError is raised.
  """
  assert self.op is Ops.PARAM and self.addrspace is None, f"op is {self.op}, need a variable PARAM (addrspace None)"
  # same guard unbind uses: a PARAM without vmin_vmax is not a variable, fail with a clear message instead of a TypeError
  bounds = self.arg.vmin_vmax
  assert bounds is not None, f"can't bind {self}, PARAM has no vmin_vmax"
  uval = self.const_like(val) if isinstance(val, int) else val
  # the message must read the range from vmin_vmax: arg is no longer the (name, min, max) tuple, so arg[1]/arg[2] are wrong here
  assert bounds[0] <= uval.vmin and uval.vmax <= bounds[1], f"bind {val} not in range [{bounds[0]}, {bounds[1]}]"
  return UOp(Ops.BIND, self.dtype, (self, uval))
def unbind(self) -> tuple[Variable, int]:
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
assert self.op is Ops.BIND and self.src[0].op is Ops.PARAM and self.src[0].arg.vmin_vmax is not None and \
self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
ret:dict[Variable, int] = {}
@ -930,6 +931,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input
if self.op is Ops.PARAM and self.arg.vmin_vmax is not None: return True
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
@ -1020,7 +1022,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.STACK}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
if self.op in {Ops.UNROLL, Ops.STACK}: return (0, 0) if len(self.src) == 0 else (min(x.vmin for x in self.src), max(x.vmax for x in self.src))
if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
if self.op is Ops.GEP: return self.src[0]._min_max
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
@ -1084,7 +1086,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def param_like(self, slot:int):
addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL
if self.op is Ops.BIND:
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].arg[0], addrspace)
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].expr, addrspace)
return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis)
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
@ -1672,7 +1674,8 @@ pm_lower_index_dtype = PatternMatcher([
# special can only be int32
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"),
lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)),
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
(UPat(Ops.PARAM, dtype=dtypes.weakint, name="u"),
lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint) if u.addrspace is None else None),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts

View file

@ -25,7 +25,7 @@ def validate_index(uidx:UOp, gate:UOp|None=None):
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
for x in idx.toposort() | gate.toposort():
if x.op in {Ops.BITCAST, Ops.STACK, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
if x.op in {Ops.BITCAST, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
# if all is good and CHECK_OOB=1, validate with z3
from tinygrad.uop.validate import validate_index_with_z3
@ -42,6 +42,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
if DEBUG >= 3: print_uops(lst)
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")
def is_shape_arg(u:UOp) -> bool: return u.dtype.scalar() in (dtypes.weakint, dtypes.int) or (u.op is Ops.STACK and len(u.src) == 0)
# ***** new specs *****
# these ops can be used in the tensor graph and programs
@ -54,6 +56,7 @@ spec_shared = PatternMatcher([
# CONST/DEFINE_VAR are everywhere
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
(UPat(Ops.STACK, src=(), dtype=dtypes.void), lambda: True),
# STACK is everywhere too
(UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True),
@ -133,7 +136,7 @@ spec_tensor = PatternMatcher([
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.PARAM), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
# custom function
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
@ -195,6 +198,9 @@ spec_program = PatternMatcher([
# no more of these in programs
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False),
# scalar shape metadata
(UPat(Ops.STACK, src=(), dtype=dtypes.void), lambda: True),
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),

View file

@ -1,8 +1,9 @@
# all of symbolic lives here now
import math, struct
from typing import cast
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.dtype import ConstType, PyConst, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
from tinygrad.uop.decompositions import threefry2x32, xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic
@ -186,7 +187,7 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True
u = u.src[0]
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
if not ((u.op in GroupOp.Irreducible or (u.op is Ops.PARAM and u.arg.vmin_vmax is not None)) and u.vmin >= 0): return None
ret.append(u)
return UOp.usum(*ret) if changed else None
@ -259,8 +260,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU/variable min==max -> CONST
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax and (x.op is not Ops.PARAM or x.addrspace is None) else None),
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
@ -343,7 +344,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
for i,(expr,v) in enumerate(bounds.items()):
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
# try checking the whole clause
all_candidates.append((expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype)))
all_candidates.append((expr, UOp.variable(f"fake{i}", cast(PyConst, v0), cast(PyConst, v1), expr.dtype)))
if try_simplex:
# every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop

View file

@ -28,6 +28,8 @@ z3_renderer = PatternMatcher([
(UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
# variables
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
(UPat(Ops.PARAM, name="x"), lambda x,ctx:
create_bounded(x.expr, x.arg.vmin_vmax[0], x.arg.vmin_vmax[1], ctx[0]) if x.arg.vmin_vmax is not None else None),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
@ -52,8 +54,10 @@ z3_renderer = PatternMatcher([
def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
# gate on upstream AFTER/BUFFER as a replacement for PtrDType, but keep INDEX as an unknown LOAD
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM} and \
(x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1]
raw = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER} and \
not (x.op is Ops.PARAM and x.arg.vmin_vmax is None) and (x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1]
param_shape_args = {p.src[0] for p in raw if p.op is Ops.PARAM and p.arg.vmin_vmax is not None and len(p.src) == 1}
lst = [u for u in raw if u not in param_shape_args]
z3map: dict[UOp, z3.ExprRef] = {}
for u in lst:
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map))