mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
8 commits
master
...
remove_def
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eee8d261e3 |
||
|
|
f1a0643b15 |
||
|
|
b8c70eb3e9 |
||
|
|
0c3fdda48b | ||
|
|
28bd22db0b | ||
|
|
1c7cc0a379 | ||
|
|
123585b93c | ||
|
|
27fe5e23a8 |
13 changed files with 53 additions and 34 deletions
2
test/external/fuzz_fast_idiv.py
vendored
2
test/external/fuzz_fast_idiv.py
vendored
|
|
@ -13,7 +13,7 @@ if __name__ == "__main__":
|
|||
print(f"Progress: {i}")
|
||||
dt = random.choice(dtypes.ints + tuple(dt.vec(4) for dt in dtypes.ints))
|
||||
u = UOp.variable('x', random.randint(dt.min, 0), random.randint(1, dt.max), dtype=dt)
|
||||
d = random.randint(1, max(1, u.arg[2])*2)
|
||||
d = random.randint(1, max(1, u.vmax)*2)
|
||||
if d in powers_of_two: continue
|
||||
expr = fast_idiv(Device[Device.DEFAULT].renderer, u, d)
|
||||
if expr is None: continue
|
||||
|
|
|
|||
2
test/external/fuzz_symbolic.py
vendored
2
test/external/fuzz_symbolic.py
vendored
|
|
@ -20,7 +20,7 @@ binary_ops = [lambda a,b: a+b, lambda a,b: a*b, lambda a,b:a.maximum(b), lambda
|
|||
comp_ops = [operator.lt, operator.le, operator.gt, operator.ge]
|
||||
|
||||
def random_or_sub_expression_int(depth, expr):
|
||||
sub_expr = random.choice([e for e in expr.toposort() if e.dtype is not dtypes.bool])
|
||||
sub_expr = random.choice([e for e in expr.toposort() if e.dtype not in (dtypes.bool, dtypes.void)])
|
||||
return random.choice([random_int_expr(depth-1), sub_expr])
|
||||
|
||||
def random_int_expr(depth=10):
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
nout = graph_rewrite(v+c1+c2, simple_pm)
|
||||
self.assertEqual(nout.op, Ops.ADD)
|
||||
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
|
||||
self.assertEqual(nout.src[0].op, Ops.PARAM)
|
||||
self.assertEqual(nout.src[1].op, Ops.CONST)
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_w
|
|||
from tinygrad.uop.validate import uops_to_z3
|
||||
|
||||
def check_uop_against_string(self, v:UOp, s:str):
|
||||
sym_vars = {v.render():v for v in v.toposort() if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL)}
|
||||
sym_vars = {v.render():v for v in v.toposort()
|
||||
if v.op in (Ops.DEFINE_VAR, Ops.RANGE, Ops.SPECIAL) or (v.op is Ops.PARAM and v.arg.vmin_vmax is not None)}
|
||||
s_eval = eval(s, sym_vars)
|
||||
if isinstance(s_eval, int) and v.dtype==dtypes.weakint: s_eval = UOp.const(dtypes.weakint, s_eval)
|
||||
elif isinstance(s_eval, (bool, int, float)): s_eval = UOp.const(dtypes.from_py(s_eval), s_eval)
|
||||
|
|
@ -1296,8 +1297,8 @@ class TestMoveWhereOnLoad(unittest.TestCase):
|
|||
# cond has a range that the rewrite can move into the valid: gate (a<4) goes into load valid
|
||||
cond = (a < 4) & (r < 2)
|
||||
valid = (a < 2) # pre-existing valid on the load (to pass can_move check for the r-only clause)
|
||||
idx = buf.index(a.valid(valid), ptr=True)
|
||||
expr = cond.where(idx, 0)
|
||||
idx = buf.index(a.valid(valid))
|
||||
expr = cond.where(idx, False)
|
||||
out = graph_rewrite(expr, pm_move_where_on_load)
|
||||
# any WHERE in the rewritten graph must have matched-dtype branches
|
||||
for u in out.toposort():
|
||||
|
|
|
|||
|
|
@ -224,7 +224,7 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual(len(lst), 1)
|
||||
graphs = [x["graph"] for x in viz.get_details(0, 0)]
|
||||
# const is always in the graph, client side hides exclude=True nodes by default
|
||||
self.assertEqual(list(graphs[0]), [id(a), id(z), id(alu), id(y), id(sink)])
|
||||
self.assertEqual(list(graphs[0]), [id(a.src[0]), id(a), id(z), id(alu), id(y), id(sink)])
|
||||
self.assertTrue(graphs[0][id(z)]["exclude"])
|
||||
self.assertTrue(graphs[0][id(y)]["exclude"])
|
||||
self.assertFalse(graphs[0][id(alu)]["exclude"])
|
||||
|
|
@ -265,7 +265,7 @@ class TestViz(unittest.TestCase):
|
|||
# VIZ displays nested graph_rewrites in a tree view
|
||||
|
||||
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
|
||||
leaf = TrackedPatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), leaf_rewrite)])
|
||||
leaf = TrackedPatternMatcher([(UPat(Ops.PARAM, name="x"), leaf_rewrite)])
|
||||
|
||||
def branch_rewrite(x:UOp, y:UOp):
|
||||
if x.tag is not None: return
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ def finalize_after(ctx:AllocCtx, x:UOp):
|
|||
def replace_input_buffer(ctx:AllocCtx, b:UOp):
|
||||
ctx.replacements.append(b)
|
||||
return UOp.param(len(ctx.replacements)-1, b.dtype, b.shape, b.device,
|
||||
b._min_max if b.op is Ops.BIND else None, b.src[0].arg[0] if b.op is Ops.BIND else None,
|
||||
b._min_max if b.op is Ops.BIND else None, b.src[0].expr if b.op is Ops.BIND else None,
|
||||
b.addrspace if isinstance(b.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL)
|
||||
|
||||
pm_finalize_call = PatternMatcher([
|
||||
|
|
@ -197,7 +197,7 @@ pm_replace_buf = PatternMatcher([
|
|||
# replace SLICE with PARAM. this rewrite is bottom up so BUFFERs we don't need won't be in the input
|
||||
(UPat(Ops.SLICE, src=(UPat(Ops.BUFFER), UPat(Ops.CONST, dtype=dtypes.weakint)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.PARAM), UPat(Ops.CONST)), name="b"), replace_input_buffer),
|
||||
])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Callify {pluralize('Buffer', len(ret[1]))}")
|
||||
|
|
|
|||
|
|
@ -113,8 +113,8 @@ pm_reduce_collapse = pm_reduce_unparented + PatternMatcher([
|
|||
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
|
||||
# AND on WHERE
|
||||
((UPat(Ops.DEFINE_VAR, name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
|
||||
((UPat((Ops.DEFINE_VAR, Ops.PARAM), name="x") & UPat.var("y")).where(UPat.var("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype) if x.op is Ops.DEFINE_VAR or x.arg.vmin_vmax is not None else None),
|
||||
# MUL casted bool
|
||||
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
|
||||
])+symbolic
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import math, functools, operator
|
||||
from typing import TYPE_CHECKING, Literal, Self
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float
|
||||
from tinygrad.dtype import dtypes, ConstType, PyConst, least_upper_dtype, least_upper_float, Invalid
|
||||
from tinygrad.helpers import argfix, polyN
|
||||
from tinygrad.mixin.dtype import DTypeMixin
|
||||
from tinygrad.mixin.creation import CreationMixin
|
||||
|
|
@ -416,7 +416,10 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
|
|||
def where(self, x: Self | ConstType, y: Self | ConstType) -> Self:
|
||||
ref: Self = x if isinstance(x, type(self)) else y if isinstance(y, type(self)) else \
|
||||
self.cast(least_upper_dtype(dtypes.from_py(x), dtypes.from_py(y)))
|
||||
return self.alu(Ops.WHERE, ref.ufix(x), ref.ufix(y))
|
||||
fx, fy = ref.ufix(x), ref.ufix(y)
|
||||
if getattr(fx, "op", None) is Ops.CONST and getattr(fx, "arg", None) is Invalid and fx.dtype != fy.dtype: fx = fy.ufix(Invalid)
|
||||
if getattr(fy, "op", None) is Ops.CONST and getattr(fy, "arg", None) is Invalid and fy.dtype != fx.dtype: fy = fx.ufix(Invalid)
|
||||
return self.alu(Ops.WHERE, fx, fy)
|
||||
|
||||
def masked_fill(self, mask:Self, value:Self|PyConst) -> Self:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ mop_cleanup = PatternMatcher([
|
|||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE, name="x2"), UPat()), name="x"), lambda x,x2: x.replace(src=(x2.src[0], x.src[1]))),
|
||||
])
|
||||
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p)), ])
|
||||
pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: ctx.append(p) if p.arg.slot >= 0 else None), ])
|
||||
def resolve_function(c:UOp, allow_param_mismatch=True) -> UOp|None:
|
||||
if c.arg.precompile: return None
|
||||
params: list[UOp] = []
|
||||
|
|
@ -533,7 +533,8 @@ to_define_global = PatternMatcher([
|
|||
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
|
||||
(UPat(Ops.PARAM, name="buf"), lambda ctx, buf:
|
||||
None if isinstance(buf.dtype, PtrDType) or buf.arg.name is not None or buf._shape is None else debuf(ctx, buf)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_VAR, Ops.PARAM), name="v"),)),
|
||||
lambda v: v if v.op is Ops.DEFINE_VAR or v.arg.vmin_vmax is not None else None),
|
||||
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
(UPat(Ops.AFTER, name="after"), handle_after),
|
||||
|
|
|
|||
|
|
@ -901,21 +901,22 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
# *** uop Variable stuff ***
|
||||
|
||||
@staticmethod
|
||||
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.weakint) -> UOp:
|
||||
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
|
||||
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
def variable(name:str, min_val:PyConst, max_val:PyConst, dtype:DType=dtypes.weakint) -> UOp:
|
||||
return UOp(Ops.PARAM, dtype, src=(shape_to_shape_arg((dtype.count,) if dtype.count > 1 else ()),),
|
||||
arg=ParamArg(-1, name=name, vmin_vmax=(min_val, max_val), addrspace=None))
|
||||
@property
|
||||
def expr(self) -> str:
|
||||
if self.op is Ops.PARAM: return unwrap(self.arg.name)
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
return self.arg[0]
|
||||
def bind(self, val:int|UOp):
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
assert self.op is Ops.PARAM and self.addrspace is None, f"op is {self.op}, need DEFINE_VAR"
|
||||
uval = self.const_like(val) if isinstance(val, int) else val
|
||||
assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
|
||||
assert self.arg.vmin_vmax[0] <= uval.vmin and uval.vmax <= self.arg.vmin_vmax[1], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
|
||||
return UOp(Ops.BIND, self.dtype, (self, uval))
|
||||
def unbind(self) -> tuple[Variable, int]:
|
||||
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
|
||||
assert self.op is Ops.BIND and self.src[0].op is Ops.PARAM and self.src[0].arg.vmin_vmax is not None and \
|
||||
self.src[1].op is Ops.CONST, f"can't unbind {self}"
|
||||
return self.src[0], self.src[1].arg
|
||||
def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
|
||||
ret:dict[Variable, int] = {}
|
||||
|
|
@ -930,6 +931,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
|
||||
def is_increasing(self:UOp) -> bool:
|
||||
# is f a monotonically increasing function regards its input
|
||||
if self.op is Ops.PARAM and self.arg.vmin_vmax is not None: return True
|
||||
if self.op in GroupOp.Irreducible: return True
|
||||
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
|
||||
if self.op in (Ops.MUL, Ops.CDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
|
||||
|
|
@ -1020,7 +1022,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
||||
if self.op in (Ops.RANGE, Ops.SPECIAL): return 0, (self.src[0]-1).vmax
|
||||
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
||||
if self.op in {Ops.UNROLL, Ops.STACK}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
||||
if self.op in {Ops.UNROLL, Ops.STACK}: return (0, 0) if len(self.src) == 0 else (min(x.vmin for x in self.src), max(x.vmax for x in self.src))
|
||||
if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
|
||||
if self.op is Ops.GEP: return self.src[0]._min_max
|
||||
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
|
||||
|
|
@ -1084,7 +1086,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
def param_like(self, slot:int):
|
||||
addrspace = self.addrspace if isinstance(self.dtype, (PtrDType, ImageDType)) else AddrSpace.GLOBAL
|
||||
if self.op is Ops.BIND:
|
||||
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].arg[0], addrspace)
|
||||
return UOp.param(slot, self.dtype, self._shape, self.device, cast(tuple[int, int], self._min_max), self.src[0].expr, addrspace)
|
||||
return UOp.param(slot, self.dtype, self.shard_shape if self.axis is not None else self._shape, self.device, addrspace=addrspace, axis=self.axis)
|
||||
|
||||
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
|
||||
|
|
@ -1672,7 +1674,8 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
# special can only be int32
|
||||
(UPat(Ops.SPECIAL, src=(UPat.var("var").cast(dtypes.weakint),), name="u"),
|
||||
lambda u,var: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.weakint)),
|
||||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
|
||||
(UPat(Ops.PARAM, dtype=dtypes.weakint, name="u"),
|
||||
lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint) if u.addrspace is None else None),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
|
||||
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
||||
# remove hanging casts
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def validate_index(uidx:UOp, gate:UOp|None=None):
|
|||
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
|
||||
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
|
||||
for x in idx.toposort() | gate.toposort():
|
||||
if x.op in {Ops.BITCAST, Ops.STACK, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
|
||||
if x.op in {Ops.BITCAST, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
|
||||
|
||||
# if all is good and CHECK_OOB=1, validate with z3
|
||||
from tinygrad.uop.validate import validate_index_with_z3
|
||||
|
|
@ -42,6 +42,8 @@ def type_verify(ast:UOp|list[UOp], check_spec:PatternMatcher):
|
|||
if DEBUG >= 3: print_uops(lst)
|
||||
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")
|
||||
|
||||
def is_shape_arg(u:UOp) -> bool: return u.dtype.scalar() in (dtypes.weakint, dtypes.int) or (u.op is Ops.STACK and len(u.src) == 0)
|
||||
|
||||
# ***** new specs *****
|
||||
|
||||
# these ops can be used in the tensor graph and programs
|
||||
|
|
@ -54,6 +56,7 @@ spec_shared = PatternMatcher([
|
|||
# CONST/DEFINE_VAR are everywhere
|
||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(x.dtype.const(x.arg))),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: len(x.arg) == 3 and isinstance(x.arg[0], str)),
|
||||
(UPat(Ops.STACK, src=(), dtype=dtypes.void), lambda: True),
|
||||
|
||||
# STACK is everywhere too
|
||||
(UPat(Ops.STACK, dtype=dtypes.void, src=()), lambda: True),
|
||||
|
|
@ -133,7 +136,7 @@ spec_tensor = PatternMatcher([
|
|||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
||||
|
||||
# Tensor variable bindings
|
||||
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
|
||||
(UPat(Ops.BIND, (dtypes.int, dtypes.weakint,), (UPat(Ops.PARAM), UPat.cvar(dtype=(dtypes.int,dtypes.weakint,))), arg=None), lambda: True),
|
||||
|
||||
# custom function
|
||||
(UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),
|
||||
|
|
@ -195,6 +198,9 @@ spec_program = PatternMatcher([
|
|||
# no more of these in programs
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.DEFINE_VAR, Ops.GEP)), lambda: False),
|
||||
|
||||
# scalar shape metadata
|
||||
(UPat(Ops.STACK, src=(), dtype=dtypes.void), lambda: True),
|
||||
|
||||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
# all of symbolic lives here now
|
||||
import math, struct
|
||||
from typing import cast
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.dtype import ConstType, PyConst, dtypes, PtrDType, can_lossless_cast, Invalid
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, unwrap, IMAGE, dedup
|
||||
from tinygrad.uop.decompositions import threefry2x32, xpow
|
||||
from tinygrad.uop.divandmod import div_and_mod_symbolic
|
||||
|
|
@ -186,7 +187,7 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
|
|||
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
|
||||
changed = True
|
||||
u = u.src[0]
|
||||
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
|
||||
if not ((u.op in GroupOp.Irreducible or (u.op is Ops.PARAM and u.arg.vmin_vmax is not None)) and u.vmin >= 0): return None
|
||||
ret.append(u)
|
||||
return UOp.usum(*ret) if changed else None
|
||||
|
||||
|
|
@ -259,8 +260,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
|
||||
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
||||
# ALU/variable min==max -> CONST
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
|
||||
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.PARAM, Ops.BIND, Ops.SPECIAL}, name="x"),
|
||||
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax and (x.op is not Ops.PARAM or x.addrspace is None) else None),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
|
|
@ -343,7 +344,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
|||
for i,(expr,v) in enumerate(bounds.items()):
|
||||
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
|
||||
# try checking the whole clause
|
||||
all_candidates.append((expr, UOp.variable(f"fake{i}", v0, v1, expr.dtype)))
|
||||
all_candidates.append((expr, UOp.variable(f"fake{i}", cast(PyConst, v0), cast(PyConst, v1), expr.dtype)))
|
||||
|
||||
if try_simplex:
|
||||
# every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ z3_renderer = PatternMatcher([
|
|||
(UPat.var("cond").where(UPat.var("x"), UPat.const(dtypes.weakint, Invalid)), lambda x,cond,ctx: (ctx[1][x], ctx[1][cond])),
|
||||
# variables
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
(UPat(Ops.PARAM, name="x"), lambda x,ctx:
|
||||
create_bounded(x.expr, x.arg.vmin_vmax[0], x.arg.vmin_vmax[1], ctx[0]) if x.arg.vmin_vmax is not None else None),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
# loads are variables bounded by the min/max of the dtype. non-pointer INDEX is also a LOAD
|
||||
|
|
@ -52,8 +54,10 @@ z3_renderer = PatternMatcher([
|
|||
|
||||
def uops_to_z3(solver:z3.Solver, *uops: UOp) -> list[z3.ExprRef]:
|
||||
# gate on upstream AFTER/BUFFER as a replacement for PtrDType, but keep INDEX as an unknown LOAD
|
||||
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM} and \
|
||||
(x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1]
|
||||
raw = list(UOp.sink(*uops).toposort(gate=lambda x: x.op not in {Ops.AFTER, Ops.BUFFER} and \
|
||||
not (x.op is Ops.PARAM and x.arg.vmin_vmax is None) and (x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.weakint) or x.op is Ops.SINK)))[:-1]
|
||||
param_shape_args = {p.src[0] for p in raw if p.op is Ops.PARAM and p.arg.vmin_vmax is not None and len(p.src) == 1}
|
||||
lst = [u for u in raw if u not in param_shape_args]
|
||||
z3map: dict[UOp, z3.ExprRef] = {}
|
||||
for u in lst:
|
||||
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver.ctx, z3map))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue