Compare commits

...

8 commits

Author SHA1 Message Date
George Hotz
eee8d261e3
Merge branch 'master' into remove_define_var 2026-06-17 17:50:53 -07:00
George Hotz
f1a0643b15
Merge branch 'master' into remove_define_var 2026-06-17 16:40:59 -07:00
George Hotz
b8c70eb3e9
Merge branch 'master' into remove_define_var 2026-06-17 15:58:57 -07:00
George Hotz
0c3fdda48b work 2026-06-17 11:29:59 -07:00
George Hotz
28bd22db0b passing? 2026-06-17 11:03:30 -07:00
George Hotz
1c7cc0a379 do things pass 2026-06-17 09:30:37 -07:00
George Hotz
123585b93c passing 2026-06-17 09:07:19 -07:00
George Hotz
27fe5e23a8 remove DEFINE_VAR 2026-06-17 08:51:11 -07:00
13 changed files with 53 additions and 34 deletions

View file

@ -13,7 +13,7 @@ if __name__ == "__main__":
print(f"Progress: {i}") print(f"Progress: {i}")
dt = random.choice(dtypes.ints + tuple(dt.vec(4) for dt in dtypes.ints)) dt = random.choice(dtypes.ints + tuple(dt.vec(4) for dt in dtypes.ints))
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt) u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
d = random.randint(1, max(1, u.arg[2])*2) d = random.randint(1, max(1, u.vmax)*2)
if d in powers_of_two: continue if d in powers_of_two: continue
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d) expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
if expr is None: continue if expr is None: continue

View file

@ -20,7 +20,7 @@ binary_ops = [lambda a,b: a+b, lambda a,b: a*b, lambda a,b:a.maximum(b), lambda
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge] comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]
def random_or_sub_expression_int(depth, expr): def random_or_sub_expression_int(depth, expr):
sub_expr = random.choice([e for e in expr.toposort() if e.dtype is not dtypes.bool]) sub_expr = random.choice([e for e in expr.toposort() if e.dtype not in (dtypes.bool, dtypes.void)])
return random.choice([random_int_expr(depth-1), sub_expr]) return random.choice([random_int_expr(depth-1), sub_expr])
def random_int_expr(depth=10): def random_int_expr(depth=10):

View file

@ -171,7 +171,7 @@ class TestGraphRewrite(unittest.TestCase):
c2 = UOp.const(dtypes.float, 2.0) c2 = UOp.const(dtypes.float, 2.0)
nout = graph_rewrite(v+c1+c2, simple_pm) nout = graph_rewrite(v+c1+c2, simple_pm)
self.assertEqual(nout.op, Ops.ADD) self.assertEqual(nout.op, Ops.ADD)
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR) self.assertEqual(nout.src[0].op, Ops.PARAM)
self.assertEqual(nout.src[1].op, Ops.CONST) self.assertEqual(nout.src[1].op, Ops.CONST)
self.assertEqual(nout.src[1].arg, 3.0) self.assertEqual(nout.src[1].arg, 3.0)

View file

@ -9,7 +9,8 @@ from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_w
from tinygrad.uop.validate import uops_to_z3 from tinygrad.uop.validate import uops_to_z3
def check_uop_against_string(self, v:UOp, s:str): def check_uop_against_string(self, v:UOp, s:str):
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)} sym_vars = {v.render():v for v in v.toposort()
if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL) or (v.op is Ops.PARAM and v.arg.vmin_vmax is not None)}
s_eval = eval(s, sym_vars) s_eval = eval(s, sym_vars)
if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval) if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval) elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)
@ -1296,8 +1297,8 @@ class TestMoveWhereOnLoad(unittest.TestCase):
# cond has a range that the rewrite can move into the valid: gate (a<4) goes into load valid # cond has a range that the rewrite can move into the valid: gate (a<4) goes into load valid
cond = (a < 4) & (r < 2) cond = (a < 4) & (r < 2)
valid = (a < 2) # pre-existing valid on the load (to pass can_move check for the r-only clause) valid = (a < 2) # pre-existing valid on the load (to pass can_move check for the r-only clause)
idx = buf.index(a.valid(valid), ptr=True) idx = buf.index(a.valid(valid))
expr = cond.where(idx, 0) expr = cond.where(idx, False)
out = graph_rewrite(expr, pm_move_where_on_load) out = graph_rewrite(expr, pm_move_where_on_load)
# any WHERE in the rewritten graph must have matched-dtype branches # any WHERE in the rewritten graph must have matched-dtype branches
for u in out.toposort(): for u in out.toposort():

View file

@ -224,7 +224,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(lst), 1) self.assertEqual(len(lst), 1)
graphs = [x["graph"] for x in viz.get_details(0, 0)] graphs = [x["graph"] for x in viz.get_details(0, 0)]
# const is always in the graph, client side hides exclude=True nodes by default # const is always in the graph, client side hides exclude=True nodes by default
self.assertEqual(list(graphs[0]), [id(a), id(z), id(alu), id(y), id(sink)]) self.assertEqual(list(graphs[0]), [id(a.src[0]), id(a), id(z), id(alu), id(y), id(sink)])
self.assertTrue(graphs[0][id(z)]["exclude"]) self.assertTrue(graphs[0][id(z)]["exclude"])
self.assertTrue(graphs[0][id(y)]["exclude"]) self.assertTrue(graphs[0][id(y)]["exclude"])
self.assertFalse(graphs[0][id(alu)]["exclude"]) self.assertFalse(graphs[0][id(alu)]["exclude"])
@ -265,7 +265,7 @@ class TestViz(unittest.TestCase):
# VIZ displays nested graph_rewrites in a tree view # VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
leaf = TrackedPatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite)]) leaf = TrackedPatternMatcher([(UPat(Ops.PARAM, name="x"), leaf_rewrite)])
def branch_rewrite(x:UOp, y:UOp): def branch_rewrite(x:UOp, y:UOp):
if x.tag is not None: return if x.tag is not None: return

View file

@ -183,7 +183,7 @@ def finalize_after(ctx:AllocCtx, x:UOp):
def replace_input_buffer(ctx:AllocCtx, b:UOp): def replace_input_buffer(ctx:AllocCtx, b:UOp):
ctx.replacements.append(b) ctx.replacements.append(b)
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device, return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device,
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None, b._min_max if b.op is Ops.BIND else None, b.src[0].expr if b.op is Ops.BIND else None,
b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL) b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL)
pm_finalize_call = PatternMatcher([ pm_finalize_call = PatternMatcher([
@ -197,7 +197,7 @@ pm_replace_buf = PatternMatcher([
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input # replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer), (UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
# strip value from BIND for cache key normalization, so different values hit same cache # strip value from BIND for cache key normalization, so different values hit same cache
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer), (UPat(Ops.BIND, src=(UPat(Ops.PARAM), UPat(Ops.CONST)), name="b"), replace_input_buffer),
]) ])
@track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}") @track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}")

View file

@ -113,8 +113,8 @@ pm_reduce_collapse = pm_reduce_unparented + PatternMatcher([
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)), lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
# AND on WHERE # AND on WHERE
((UPat(Ops.DEFINE_VAR, name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), ((UPat((Ops.DEFINE_VAR, Ops.PARAM), name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)), lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype) if x.op is Ops.DEFINE_VAR or x.arg.vmin_vmax is not None else None),
# MUL casted bool # MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)), ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
])+symbolic ])+symbolic

View file

@ -1,7 +1,7 @@
import math, functools, operator import math, functools, operator
from typing import TYPE_CHECKING, Literal, Self from typing import TYPE_CHECKING, Literal, Self
from tinygrad.uop import Ops from tinygrad.uop import Ops
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float, Invalid
from tinygrad.helpers import argfix, polyN from tinygrad.helpers import argfix, polyN
from tinygrad.mixin.dtype import DTypeMixin from tinygrad.mixin.dtype import DTypeMixin
from tinygrad.mixin.creation import CreationMixin from tinygrad.mixin.creation import CreationMixin
@ -416,7 +416,10 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
def where(self, x: Self | ConstType, y: Self | ConstType) -> Self: def where(self, x: Self | ConstType, y: Self | ConstType) -> Self:
ref: Self = x if isinstance(x, type(self)) else y if isinstance(y, type(self)) else \ ref: Self = x if isinstance(x, type(self)) else y if isinstance(y, type(self)) else \
self.cast(least_upper_dtype(dtypes.from_py(x), dtypes.from_py(y))) self.cast(least_upper_dtype(dtypes.from_py(x), dtypes.from_py(y)))
return self.alu(Ops.WHERE, ref.ufix(x), ref.ufix(y)) fx, fy = ref.ufix(x), ref.ufix(y)
if getattr(fx, "op", None) is Ops.CONST and getattr(fx, "arg", None) is Invalid and fx.dtype != fy.dtype: fx = fy.ufix(Invalid)
if getattr(fy, "op", None) is Ops.CONST and getattr(fy, "arg", None) is Invalid and fy.dtype != fx.dtype: fy = fx.ufix(Invalid)
return self.alu(Ops.WHERE, fx, fy)
def masked_fill(self, mask:Self, value:Self|PyConst) -> Self: def masked_fill(self, mask:Self, value:Self|PyConst) -> Self:
""" """

View file

@ -127,7 +127,7 @@ mop_cleanup = PatternMatcher([
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))), (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
]) ])
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ]) pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p) if p.arg.slot >= 0 else None), ])
def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None: def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
if c.arg.precompile: return None if c.arg.precompile: return None
params: list[UOp] = [] params: list[UOp] = []
@ -533,7 +533,8 @@ to_define_global = PatternMatcher([
if v.arg.name is not None and v.arg.vmin_vmax is not None else None), if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf: (UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)), None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v), (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_VAR, Ops.PARAM), name="v"),)),
lambda v: v if v.op is Ops.DEFINE_VAR or v.arg.vmin_vmax is not None else None),
(UPat(Ops.BIND, name="b"), unbind_kernel), (UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat(Ops.AFTER, name="after"), handle_after), (UPat(Ops.AFTER, name="after"), handle_after),

View file

@ -901,21 +901,22 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
# *** uop Variable stuff *** # *** uop Variable stuff ***
@staticmethod @staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.weakint) -> UOp: def variable(name:str, min_val:PyConst, max_val:PyConst, dtype:DType=dtypes.weakint) -> UOp:
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}" return UOp(Ops.PARAM, dtype, src=(shape_to_shape_arg((dtype.count,) if dtype.count > 1 else ()),),
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) arg=ParamArg(-1, name=name, vmin_vmax=(min_val, max_val), addrspace=None))
@property @property
def expr(self) -> str: def expr(self) -> str:
if self.op is Ops.PARAM: return unwrap(self.arg.name) if self.op is Ops.PARAM: return unwrap(self.arg.name)
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0] return self.arg[0]
def bind(self, val:int|UOp): def bind(self, val:int|UOp):
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" assert self.op is Ops.PARAM and self.addrspace is None, f"op is {self.op}, need DEFINE_VAR"
uval = self.const_like(val) if isinstance(val, int) else val uval = self.const_like(val) if isinstance(val, int) else val
assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]" assert self.arg.vmin_vmax[0] <= uval.vmin and uval.vmax <= self.arg.vmin_vmax[1], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
return UOp(Ops.BIND, self.dtype, (self, uval)) return UOp(Ops.BIND, self.dtype, (self, uval))
def unbind(self) -> tuple[Variable, int]: def unbind(self) -> tuple[Variable, int]:
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}" assert self.op is Ops.BIND and self.src[0].op is Ops.PARAM and self.src[0].arg.vmin_vmax is not None and \
self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg return self.src[0], self.src[1].arg
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]: def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
ret:dict[Variable, int] = {} ret:dict[Variable, int] = {}
@ -930,6 +931,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def is_increasing(self:UOp) -> bool: def is_increasing(self:UOp) -> bool:
# is f a monotonically increasing function regards its input # is f a monotonically increasing function regards its input
if self.op is Ops.PARAM and self.arg.vmin_vmax is not None: return True
if self.op in GroupOp.Irreducible: return True if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing() if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing() if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
@ -1020,7 +1022,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.STACK}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src) if self.op in {Ops.UNROLL, Ops.STACK}: return (0, 0) if len(self.src) == 0 else (min(x.vmin for x in self.src), max(x.vmax for x in self.src))
if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
if self.op is Ops.GEP: return self.src[0]._min_max if self.op is Ops.GEP: return self.src[0]._min_max
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified # TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
@ -1084,7 +1086,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def param_like(self, slot:int): def param_like(self, slot:int):
addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL
if self.op is Ops.BIND: if self.op is Ops.BIND:
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].arg[0], addrspace) return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].expr, addrspace)
return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis) return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis)
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE) # opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
@ -1672,7 +1674,8 @@ pm_lower_index_dtype = PatternMatcher([
# special can only be int32 # special can only be int32
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"), (UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"),
lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)), lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)),
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)), (UPat(Ops.PARAM, dtype=dtypes.weakint, name="u"),
lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint) if u.addrspace is None else None),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))), (UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
lambda var,val: var.bind(val).cast(dtypes.weakint)), lambda var,val: var.bind(val).cast(dtypes.weakint)),
# remove hanging casts # remove hanging casts

View file

@ -25,7 +25,7 @@ def validate_index(uidx:UOp, gate:UOp|None=None):
# WEBGPU has a BITCAST in the index, PTX casts pointer to long # WEBGPU has a BITCAST in the index, PTX casts pointer to long
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors # VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
for x in idx.toposort() | gate.toposort(): for x in idx.toposort() | gate.toposort():
if x.op in {Ops.BITCAST, Ops.STACK, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True if x.op in {Ops.BITCAST, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
# if all is good and CHECK_OOB=1, validate with z3 # if all is good and CHECK_OOB=1, validate with z3
from tinygrad.uop.validate import validate_index_with_z3 from tinygrad.uop.validate import validate_index_with_z3
@ -42,6 +42,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
if DEBUG >= 3: print_uops(lst) if DEBUG >= 3: print_uops(lst)
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}") raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")
def is_shape_arg(u:UOp) -> bool: return u.dtype.scalar() in (dtypes.weakint, dtypes.int) or (u.op is Ops.STACK and len(u.src) == 0)
# ***** new specs ***** # ***** new specs *****
# these ops can be used in the tensor graph and programs # these ops can be used in the tensor graph and programs
@ -54,6 +56,7 @@ spec_shared = PatternMatcher([
# CONST/DEFINE_VAR are everywhere # CONST/DEFINE_VAR are everywhere
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))), (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)), (UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
(UPat(Ops.STACK, src=(), dtype=dtypes.void), lambda: True),
# STACK is everywhere too # STACK is everywhere too
(UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True), (UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True),
@ -133,7 +136,7 @@ spec_tensor = PatternMatcher([
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)), lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
# Tensor variable bindings # Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True), (UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.PARAM), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
# custom function # custom function
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)), (UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
@ -195,6 +198,9 @@ spec_program = PatternMatcher([
# no more of these in programs # no more of these in programs
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False),
# scalar shape metadata
(UPat(Ops.STACK, src=(), dtype=dtypes.void), lambda: True),
# weakint is not allowed in programs # weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False), (UPat(GroupOp.All, dtypes.weakint), lambda: False),

View file

@ -1,8 +1,9 @@
# all of symbolic lives here now # all of symbolic lives here now
import math, struct import math, struct
from typing import cast
from collections import defaultdict from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid from tinygrad.dtype import ConstType, PyConst, dtypes, PtrDType, can_lossless_cast, Invalid
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
from tinygrad.uop.decompositions import threefry2x32, xpow from tinygrad.uop.decompositions import threefry2x32, xpow
from tinygrad.uop.divandmod import div_and_mod_symbolic from tinygrad.uop.divandmod import div_and_mod_symbolic
@ -186,7 +187,7 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True changed = True
u = u.src[0] u = u.src[0]
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None if not ((u.op in GroupOp.Irreducible or (u.op is Ops.PARAM and u.arg.vmin_vmax is not None)) and u.vmin >= 0): return None
ret.append(u) ret.append(u)
return UOp.usum(*ret) if changed else None return UOp.usum(*ret) if changed else None
@ -259,8 +260,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \ ((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU/variable min==max -> CONST # ALU/variable min==max -> CONST
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"), (UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax and (x.op is not Ops.PARAM or x.addrspace is None) else None),
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), (UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding # max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None), (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
@ -343,7 +344,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
for i,(expr,v) in enumerate(bounds.items()): for i,(expr,v) in enumerate(bounds.items()):
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
# try checking the whole clause # try checking the whole clause
all_candidates.append((expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype))) all_candidates.append((expr, UOp.variable(f"fake{i}", cast(PyConst, v0), cast(PyConst, v1), expr.dtype)))
if try_simplex: if try_simplex:
# every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop

View file

@ -28,6 +28,8 @@ z3_renderer = PatternMatcher([
(UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])), (UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
# variables # variables
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])), (UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
(UPat(Ops.PARAM, name="x"), lambda x,ctx:
create_bounded(x.expr, x.arg.vmin_vmax[0], x.arg.vmin_vmax[1], ctx[0]) if x.arg.vmin_vmax is not None else None),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])), (UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])), (UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD # loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
@ -52,8 +54,10 @@ z3_renderer = PatternMatcher([
def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]: def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
# gate on upstream AFTER/BUFFER as a replacement for PtrDType, but keep INDEX as an unknown LOAD # gate on upstream AFTER/BUFFER as a replacement for PtrDType, but keep INDEX as an unknown LOAD
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM} and \ raw = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER} and \
(x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1] not (x.op is Ops.PARAM and x.arg.vmin_vmax is None) and (x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1]
param_shape_args = {p.src[0] for p in raw if p.op is Ops.PARAM and p.arg.vmin_vmax is not None and len(p.src) == 1}
lst = [u for u in raw if u not in param_shape_args]
z3map: dict[UOp, z3.ExprRef] = {} z3map: dict[UOp, z3.ExprRef] = {}
for u in lst: for u in lst:
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map)) z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map))