mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove DEFINE_VAR try 2 (#16651)
* remove DEFINE_VAR try 2 * param * null index * fix fuzzing * fixes * no gather neg params * param is just Irreducible * fixes * skip stack * need to filter slots there
This commit is contained in:
parent
d37248c3ec
commit
5989d0b150
14 changed files with 36 additions and 31 deletions
2
test/external/fuzz_fast_idiv.py
vendored
2
test/external/fuzz_fast_idiv.py
vendored
|
|
@ -13,7 +13,7 @@ if __name__ == "__main__":
|
|||
print(f"Progress: {i}")
|
||||
dt = random.choice(dtypes.ints + tuple(dt.vec(4) for dt in dtypes.ints))
|
||||
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
||||
d = random.randint(1, max(1, u.arg[2])*2)
|
||||
d = random.randint(1, max(1, u.vmax)*2)
|
||||
if d in powers_of_two: continue
|
||||
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
|
||||
if expr is None: continue
|
||||
|
|
|
|||
2
test/external/fuzz_symbolic.py
vendored
2
test/external/fuzz_symbolic.py
vendored
|
|
@ -20,7 +20,7 @@ binary_ops = [lambda a,b: a+b, lambda a,b: a*b, lambda a,b:a.maximum(b), lambda
|
|||
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]
|
||||
|
||||
def random_or_sub_expression_int(depth, expr):
|
||||
sub_expr = random.choice([e for e in expr.toposort() if e.dtype is not dtypes.bool])
|
||||
sub_expr = random.choice([e for e in expr.toposort() if e.dtype not in (dtypes.bool, dtypes.void)])
|
||||
return random.choice([random_int_expr(depth-1), sub_expr])
|
||||
|
||||
def random_int_expr(depth=10):
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
nout = graph_rewrite(v+c1+c2, simple_pm)
|
||||
self.assertEqual(nout.op, Ops.ADD)
|
||||
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
|
||||
self.assertEqual(nout.src[0].op, Ops.PARAM)
|
||||
self.assertEqual(nout.src[1].op, Ops.CONST)
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_w
|
|||
from tinygrad.uop.validate import uops_to_z3
|
||||
|
||||
def check_uop_against_string(self, v:UOp, s:str):
|
||||
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}
|
||||
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL, Ops.PARAM)}
|
||||
s_eval = eval(s, sym_vars)
|
||||
if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
|
||||
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual(len(lst), 1)
|
||||
graphs = [x["graph"] for x in viz.get_details(0, 0)]
|
||||
# const is always in the graph, client side hides exclude=True nodes by default
|
||||
self.assertEqual(list(graphs[0]), [id(a), id(z), id(alu), id(y), id(sink)])
|
||||
self.assertEqual(list(graphs[0]), [id(a.src[0]), id(a), id(z), id(alu), id(y), id(sink)])
|
||||
self.assertTrue(graphs[0][id(z)]["exclude"])
|
||||
self.assertTrue(graphs[0][id(y)]["exclude"])
|
||||
self.assertFalse(graphs[0][id(alu)]["exclude"])
|
||||
|
|
@ -267,7 +267,7 @@ class TestViz(unittest.TestCase):
|
|||
# VIZ displays nested graph_rewrites in a tree view
|
||||
|
||||
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
|
||||
leaf = TrackedPatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite)])
|
||||
leaf = TrackedPatternMatcher([(UPat(Ops.PARAM, name="x"), leaf_rewrite)])
|
||||
|
||||
def branch_rewrite(x:UOp, y:UOp):
|
||||
if x.tag is not None: return
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ def finalize_after(ctx:AllocCtx, x:UOp):
|
|||
def replace_input_buffer(ctx:AllocCtx, b:UOp):
|
||||
ctx.replacements.append(b)
|
||||
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device,
|
||||
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None,
|
||||
b._min_max if b.op is Ops.BIND else None, b.src[0].expr if b.op is Ops.BIND else None,
|
||||
b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL)
|
||||
|
||||
pm_finalize_call = PatternMatcher([
|
||||
|
|
@ -197,7 +197,7 @@ pm_replace_buf = PatternMatcher([
|
|||
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
|
||||
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.PARAM), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}")
|
||||
|
|
|
|||
|
|
@ -328,7 +328,7 @@ class Scheduler:
|
|||
def group_for_reduces(self) -> int: return len(self.axes_of(AxisType.GROUP_REDUCE))
|
||||
|
||||
def bufs_from_ast(ast:UOp, dname:str) -> list[Buffer]:
|
||||
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM], key=lambda x: x.arg.slot)
|
||||
glbls = sorted([x for x in ast.backward_slice if x.op is Ops.PARAM and x.arg.slot >= 0], key=lambda x: x.arg.slot)
|
||||
return [Buffer(dname, x.max_numel(), x.dtype.base) for x in glbls]
|
||||
|
||||
def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ pm_reduce_collapse = pm_reduce_unparented + PatternMatcher([
|
|||
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
|
||||
# AND on WHERE
|
||||
((UPat(Ops.DEFINE_VAR, name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
((UPat(Ops.PARAM, name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
|
||||
# MUL casted bool
|
||||
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ mop_cleanup = PatternMatcher([
|
|||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
|
||||
])
|
||||
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p) if p.arg.slot >= 0 else None), ])
|
||||
def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||
if c.arg.precompile: return None
|
||||
params: list[UOp] = []
|
||||
|
|
@ -533,7 +533,9 @@ to_define_global = PatternMatcher([
|
|||
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
|
||||
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
|
||||
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
|
||||
|
||||
# this was DEFINE_VAR, clean this up and make it universal
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PARAM, name="v"),)), lambda v: v if v.addrspace == AddrSpace.ALU else None),
|
||||
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
(UPat(Ops.AFTER, name="after"), handle_after),
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class GroupOp:
|
|||
|
||||
Defines = {Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE, Ops.PARAM}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
||||
# BinaryOps that can be flipped
|
||||
|
|
|
|||
|
|
@ -903,21 +903,20 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
# *** uop Variable stuff ***
|
||||
|
||||
@staticmethod
|
||||
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.weakint) -> UOp:
|
||||
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
|
||||
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
def variable(name:str, min_val:PyConst, max_val:PyConst, dtype:DType=dtypes.weakint) -> UOp:
|
||||
return UOp(Ops.PARAM, dtype, src=(shape_to_shape_arg((dtype.count,) if dtype.count > 1 else ()),),
|
||||
arg=ParamArg(-1, name=name, vmin_vmax=(min_val, max_val), addrspace=AddrSpace.ALU))
|
||||
@property
|
||||
def expr(self) -> str:
|
||||
if self.op is Ops.PARAM: return unwrap(self.arg.name)
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
return self.arg[0]
|
||||
assert self.op is Ops.PARAM
|
||||
return unwrap(self.arg.name)
|
||||
def bind(self, val:int|UOp):
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
assert self.op is Ops.PARAM and self.addrspace is AddrSpace.ALU, f"op is {self.op}, need PARAM"
|
||||
uval = self.const_like(val) if isinstance(val, int) else val
|
||||
assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
|
||||
assert self.vmin <= uval.vmin and uval.vmax <= self.vmax, f"bind {val} not in range [{self.vmin}, {self.vmax}]"
|
||||
return UOp(Ops.BIND, self.dtype, (self, uval))
|
||||
def unbind(self) -> tuple[Variable, int]:
|
||||
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
|
||||
assert self.op is Ops.BIND and self.src[0].op is Ops.PARAM and self.src[1].op is Ops.CONST, f"can't unbind {self}"
|
||||
return self.src[0], self.src[1].arg
|
||||
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
|
||||
ret:dict[Variable, int] = {}
|
||||
|
|
@ -1086,7 +1085,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
def param_like(self, slot:int):
|
||||
addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL
|
||||
if self.op is Ops.BIND:
|
||||
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].arg[0], addrspace)
|
||||
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].expr, addrspace)
|
||||
return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis)
|
||||
|
||||
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
|
||||
|
|
@ -1675,7 +1674,8 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
# special can only be int32
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"),
|
||||
lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)),
|
||||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
|
||||
(UPat(Ops.PARAM, dtype=dtypes.weakint, name="u"),
|
||||
lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint) if u.addrspace == AddrSpace.ALU else None),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
|
||||
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
||||
# remove hanging casts
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ def validate_index(uidx:UOp, gate:UOp|None=None):
|
|||
# TODO: validate these
|
||||
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
|
||||
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
|
||||
for x in idx.toposort() | gate.toposort():
|
||||
# don't descend into PARAM shape metadata; only the PARAM value participates in index arithmetic
|
||||
for x in idx.toposort(gate=lambda x: x.op is not Ops.PARAM) | gate.toposort(gate=lambda x: x.op is not Ops.PARAM):
|
||||
if x.op in {Ops.BITCAST, Ops.STACK, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
|
||||
|
||||
# if all is good and CHECK_OOB=1, validate with z3
|
||||
|
|
@ -136,7 +137,7 @@ spec_tensor = PatternMatcher([
|
|||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
|
||||
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.PARAM), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
|
||||
|
||||
# custom function
|
||||
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import math, struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.dtype import PyConst, ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
|
||||
from tinygrad.uop.decompositions import threefry2x32, xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
|
|
@ -259,7 +259,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
|
||||
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
||||
# ALU/variable min==max -> CONST
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
|
||||
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
|
|
@ -332,7 +332,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
|||
# return simplified uop (might be the same as input)
|
||||
|
||||
# first, parse valid into {expr: (lower_bound, upper_bound)}
|
||||
bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
|
||||
bounds:defaultdict[UOp, list[PyConst|None]] = defaultdict(lambda: [None, None])
|
||||
for stmt in valid.split_uop(Ops.AND):
|
||||
if (res:=parse_valid(stmt)) is None: continue
|
||||
expr, is_upper, c = res
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ z3_renderer = PatternMatcher([
|
|||
(UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
|
||||
# variables
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
|
||||
(UPat(Ops.PARAM, name="x"), lambda x,ctx: create_bounded(x.arg.name, x.vmin, x.vmax, ctx[0])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
|
||||
(UPat((Ops.LOAD, Ops.INDEX), dtypes.ints+(dtypes.weakint,), name="x"), lambda x,ctx:
|
||||
|
|
@ -52,10 +52,12 @@ z3_renderer = PatternMatcher([
|
|||
|
||||
def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
|
||||
# gate on upstream AFTER/BUFFER as a replacement for PtrDType, but keep INDEX as an unknown LOAD
|
||||
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM} and \
|
||||
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER} and \
|
||||
(x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1]
|
||||
z3map: dict[UOp, z3.ExprRef] = {}
|
||||
for u in lst:
|
||||
# NOTE: we skip STACK here, it can't actually be accessed
|
||||
if u.op is Ops.STACK: continue
|
||||
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map))
|
||||
if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3")
|
||||
new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue