Remove Ops.VCONST (#16267)

* start removing vconst

* remove a lot of vconst

* const folding + strict ordering

* update tests

* spec from minigen

* move that
This commit is contained in:
George Hotz 2026-05-19 16:35:24 -07:00 committed by GitHub
commit 55515747b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 111 additions and 86 deletions

View file

@ -127,10 +127,9 @@ class TestBitcastConstFolding(unittest.TestCase):
def test_vec_bitcast(self):
with Context(SPEC=0):
r = full_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
self.assertEqual(r.op, Ops.STACK)
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
srcs = full_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src
self.assertTrue(all(r.op is Ops.CONST and r.dtype == dtypes.uint32 for r in srcs))
self.assertEqual(tuple(x.arg for x in srcs), (2**32-1, 2**31, 75))
# folds advance indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):

View file

@ -10,6 +10,14 @@ from hypothesis import given, strategies as strat
def apply_rewrite(expr):
return full_rewrite(expr.sink()).src[0]
@Context(SPEC=0)
def apply_rewrite_values(expr):
srcs = full_rewrite(expr.sink()).src
if len(srcs) == 1:
if srcs[0].op is Ops.CONST: return (srcs[0].arg,)*srcs[0].dtype.count
if srcs[0].op is Ops.STACK: return tuple(s.arg for s in srcs[0].src)
return tuple(s.arg for s in srcs)
def evaluate_uop(uop, variables):
if uop.op == Ops.CONST:
return uop.arg
@ -121,7 +129,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
def test_graph_rewrite_div_folding_bug(self):
lhs = UOp(Ops.ADD, dtypes.int.vec(4), src=(
UOp(Ops.STACK, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg='lidx0', src=(UOp.const(dtypes.int, 32),)),)*4),
UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
UOp.const(dtypes.int.vec(4), (0, 256, 512, 768))))
rhs = UOp.const(dtypes.int.vec(4), 2)
unopt = lhs<rhs
opt = apply_rewrite(unopt)
@ -180,19 +188,17 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
def test_gep_tuple_extraction(self):
# GEP on a vector dtype to extract multiple elements as a vector
base_vector = UOp.const(dtypes.float32.vec(4), (1.0, 2.0, 3.0, 4.0))
optimized_uop = apply_rewrite(base_vector.gep((2, 3)))
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [3.0, 4.0])
self.assertEqual(list(apply_rewrite_values(base_vector.gep((2, 3)))), [3.0, 4.0])
def test_gep_on_vconst(self):
# GEP on a VCONST to extract a single element
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
self.assertEqual(apply_rewrite(vconst.gep(2)).arg, 3.0)
def test_gep_on_const_stack(self):
# GEP on a const STACK to extract a single element
const_stack = UOp.const(dtypes.float32.vec(4), (1.0, 2.0, 3.0, 4.0))
self.assertEqual(apply_rewrite(const_stack.gep(2)).arg, 3.0)
def test_gep_tuple_on_vconst(self):
# GEP on a VCONST using a tuple to extract multiple elements
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
optimized_uop = apply_rewrite(vconst.gep((1, 3)))
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [8.0, 10.0])
def test_gep_tuple_on_const_stack(self):
# GEP on a const STACK using a tuple to extract multiple elements
const_stack = UOp.const(dtypes.float32.vec(4), (7.0, 8.0, 9.0, 10.0))
self.assertEqual(list(apply_rewrite_values(const_stack.gep((1, 3)))), [8.0, 10.0])
def test_gep_gep_simplification(self):
# Nested GEP simplification on a vector dtype
@ -204,8 +210,7 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
# Vectorizing multiple elements using GEP
base_vector = UOp.const(dtypes.float32.vec(4), (5.0, 10.0, 15.0, 20.0))
vectorized_uop = UOp(Ops.STACK, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
optimized_uop = apply_rewrite(vectorized_uop)
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [5.0, 10.0, 15.0, 20.0])
self.assertEqual(list(apply_rewrite_values(vectorized_uop)), [5.0, 10.0, 15.0, 20.0])
import inspect

View file

@ -14,6 +14,11 @@ simple_pm = PatternMatcher([
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
])
def const_values(u:UOp):
if u.op is Ops.CONST: return (u.arg,)*u.dtype.count
if u.op is Ops.STACK: return tuple(x.arg for x in u.src)
raise AssertionError(f"expected const-like UOp, got {u.op}")
class TestGraphRewriteConst(unittest.TestCase):
def test_gep_const(self):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
@ -33,9 +38,9 @@ class TestGraphRewriteConst(unittest.TestCase):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
ret = graph_rewrite(v1+v2, sym)
self.assertEqual(ret.op, Ops.VCONST)
self.assertEqual(ret.op, Ops.STACK)
self.assertEqual(ret.dtype, dtypes.int.vec(3))
self.assertEqual(ret.arg, (5,7,9))
self.assertEqual(const_values(ret), (5,7,9))
def test_add_const_lose_v(self):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
@ -587,58 +592,61 @@ class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+3)
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 4
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
assert sink.op is Ops.UNROLL and len(const_values(sink.src[0])) == 4
self.assertTupleEqual(const_values(sink.src[0]), (3,4,5,6))
def test_contract_simple(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
self.assertEqual(sink.op, Ops.VCONST)
self.assertTupleEqual(sink.arg, (0,1,2,3))
self.assertEqual(sink.op, Ops.STACK)
self.assertTupleEqual(const_values(sink), (0,1,2,3))
def test_contract_axis_1(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is Ops.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
vals = const_values(sink.src[0])
assert sink.op is Ops.UNROLL and len(vals) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is Ops.STACK
self.assertTupleEqual(vals[0:4], (0,4,8,12))
self.assertTupleEqual(vals[12:], (3,7,11,15))
def test_contract_axis_2(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
sink = expander_rewrite(con)
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is Ops.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
vals = const_values(sink.src[0])
assert sink.op is Ops.UNROLL and len(vals) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is Ops.STACK
self.assertTupleEqual(vals[0:4], (0,1,2,3))
self.assertTupleEqual(vals[12:], (12,13,14,15))
def test_contract_axis_2_big(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (3, 2), (4, 2))
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
vals = const_values(sink.src[0])
self.assertTupleEqual(vals[0:2], (0,4))
self.assertTupleEqual(vals[12:14], (10,14))
def test_contract_multi_axis(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (4, 2))
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
self.assertTupleEqual(const_values(sink.src[0])[0:4], (0, 4, 2, 6))
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (4, 2))
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
self.assertTupleEqual(const_values(sink.src[0])[0:4], (0, 2, 4, 6))
def test_contract_mid(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is Ops.UNROLL and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is Ops.VCONST and len(sink.src[0].arg) == 8
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
assert sink.src[0].op is Ops.STACK and len(const_values(sink.src[0])) == 8
self.assertTupleEqual(const_values(sink.src[0]), (0,2,1,3,4,6,5,7))
def test_contract_no_expand(self):
e1 = UOp.variable("i", 0, 10, dtype=dtypes.int)
@ -651,26 +659,28 @@ class TestExpander(unittest.TestCase):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(Ops.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
sink = expander_rewrite(con)
assert sink.op is Ops.VCONST and len(sink.arg) == 8
assert sink.arg[0] == sink.arg[1]
assert sink.arg[0] != sink.arg[2]
assert sink.arg[6] == sink.arg[7]
vals = const_values(sink)
assert sink.op is Ops.STACK and len(vals) == 8
assert vals[0] == vals[1]
assert vals[0] != vals[2]
assert vals[6] == vals[7]
def test_expand_same_axis(self):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e2 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+e2)
self.assertEqual(sink.op, Ops.UNROLL)
self.assertEqual(sink.src[0].op, Ops.VCONST)
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
self.assertEqual(sink.src[0].op, Ops.STACK)
self.assertTupleEqual(const_values(sink.src[0]), (0,5,10,15))
def test_expand_different_axis(self, flip=False):
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e2 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16
vals = const_values(sink.src[0])
assert sink.op is Ops.UNROLL and len(vals) == 16
assert sink.arg == ((1, 4), (2, 4))
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
self.assertTupleEqual(vals, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)

View file

@ -361,7 +361,7 @@ class TestUOpRender(unittest.TestCase):
self.assertEqual(u.render(), "0")
def test_render_vectorize_different_simplified(self):
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
self.assertEqual(u.render(), "(0, 1, 2)")
self.assertEqual(u.render(), "{0,1,2}")
if __name__ == '__main__':
unittest.main()

View file

@ -41,7 +41,7 @@ class TestUPatCompile(unittest.TestCase):
do_compile(up)
def test_const_folding(self):
up = UPat(GroupOp.ALU-{Ops.THREEFRY}, name="a", src=UPat((Ops.VCONST, Ops.CONST)))
up = UPat(GroupOp.ALU-{Ops.THREEFRY}, name="a", src=UPat((Ops.CONST, Ops.STACK)))
do_compile(up)
@unittest.skip("fix this")

View file

@ -273,7 +273,6 @@ pm_render = PatternMatcher([
# for rendering, we use explicit VECTORIZE
(UPat(Ops.CONST, name='c'),
lambda c: UOp(Ops.STACK, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.STACK, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),

View file

@ -16,6 +16,6 @@ pm_move_gates_from_index = PatternMatcher([
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
# images use 2D INDEX now (y,x)
(UPat(Ops.INDEX, src=(UPat(), UPat((Ops.CONST, Ops.VCONST, Ops.STACK), name="vec")), name="idx"),
(UPat(Ops.INDEX, src=(UPat(), UPat((Ops.CONST, Ops.STACK), name="vec")), name="idx"),
lambda idx,vec: idx.replace(src=(idx.src[0], vec.gep(1).cast(dtypes.int), vec.gep(0).cast(dtypes.int))) if vec.dtype.count == 2 else None),
])

View file

@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
replaces: dict[UOp, UOp] = {}
for u in included:
for s in u.src:
if s in included or s in replaces or s.op in {Ops.CONST, Ops.VCONST, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: continue
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: continue
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")

View file

@ -78,8 +78,8 @@ class Ops(FastEnum):
# control flow ops
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto(); WAIT = auto()
# consts. VCONST is a vectorized const
VCONST = auto(); CONST = auto()
# const.
CONST = auto()
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
CUSTOM = auto(); CUSTOMI = auto()

View file

@ -65,7 +65,7 @@ def fold_divmod_general(d: UOp) -> UOp|None:
# nest_by_factor: x//c -> (x//f)//(c//f), x%c -> (x//f%(c//f))*f + b where b=x%f
# FLOORDIV identity holds for any sign of x; FLOORMOD reconstruction needs x.vmin>=0
results = []
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op is not Ops.CONST and 1 < abs(f) < c and (c%f)==0}:
if (newxs := fold_divmod_general(x//div)) is not None:
if d.op is Ops.FLOORDIV:
results.append((len(newxs.backward_slice), newxs // (c // div)))

View file

@ -251,7 +251,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# TODO: these should have the shape of the dtype.count
case Ops.CONST | Ops.DEFINE_VAR: return ()
case Ops.GEP | Ops.STACK | Ops.VCONST | Ops.VCAT | Ops.GETADDR: return ()
case Ops.GEP | Ops.STACK | Ops.VCAT | Ops.GETADDR: return ()
# some ops init the shape
case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
@ -395,8 +395,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# *** uop evaluation ***
def simplify(self, tracked=False):
if self.op in {Ops.CONST, Ops.VCONST}: return self
if self.op is Ops.SINK and all(s.op in {Ops.CONST, Ops.VCONST} or (s.op is Ops.STACK and len(s.src) == 0) for s in self.src): return self
if self.op is Ops.CONST: return self
if self.op is Ops.SINK and all(s.op is Ops.CONST or (s.op is Ops.STACK and len(s.src) == 0) for s in self.src): return self
# late import!
from tinygrad.uop.symbolic import symbolic
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
@ -482,7 +482,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if isinstance(i, int):
# NOTE: these are just shortcuts to not have to create and fold later
if self.op is Ops.STACK: return self.src[i]
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
i = (i,)
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
@ -513,10 +512,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b):
assert len(b) > 0, "can't create const from empty tuple"
b = b[0] # doesn't have to be a VCONST if they are all the same
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
arg=dtype.const(b),
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
b = b[0] # doesn't have to be a STACK if they are all the same
if isinstance(b, tuple):
stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ()) for c in b]
ret = UOp.vectorize(*stk)
else:
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
@staticmethod
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
@ -656,7 +657,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def sgep(self, i:int) -> sint:
match self.op:
case Ops.CONST: return self.arg
case Ops.VCONST: return self.arg[i]
case Ops.STACK: return self.src[i].sintify()
case _: raise RuntimeError(f"no sgep on {self.op}")
@ -850,14 +850,16 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
"""largest known int that divides self"""
# TODO: for negatives it's not the largest
if self.op is Ops.CONST: return self.arg
if self.op is Ops.VCONST: return math.gcd(*self.arg)
if self.op is Ops.STACK: return math.gcd(*[x.const_factor() for x in self.src])
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
def divides(self, v:int) -> UOp|None:
if v==1: return self
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
if self.op is Ops.STACK:
srcs = tuple(s.divides(v) for s in self.src)
return None if any(s is None for s in srcs) else UOp(Ops.STACK, self.dtype, cast(tuple[UOp, ...], srcs))
if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.op is Ops.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
@ -931,7 +933,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.STACK}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg))
if self.op is Ops.GEP: return self.src[0]._min_max
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.weakint,):
@ -1185,7 +1186,7 @@ class UPat(OpMixin):
@staticmethod
@functools.cache
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True, arg=None):
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name, arg=arg)
return UPat(Ops.CONST, dtype, name=name, arg=arg)
@staticmethod
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
@ -1572,8 +1573,7 @@ pm_lower_index_dtype = PatternMatcher([
# There are no Unary ops at this point in symbolic, those are introduced later
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)),
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.weakint, name="u"),
lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
(UPat(Ops.CONST, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
(UPat(Ops.WHERE, dtypes.weakint, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda cond,x,y:
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.weakint)),
(UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.weakint)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.weakint)),

View file

@ -37,7 +37,7 @@ renderer = PatternMatcher([
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.NOOP, name="x"))), lambda x: x.arg),
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),
(UPat(Ops.CONST, name="x"), lambda x: str(x.arg)),
(UPat(Ops.UNROLL, name="x"), lambda ctx,x,u: f"UNROLL({ctx[x.src[0]]}, {u.arg})"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({str(x.dtype)[7:]})({ctx[x.src[0]]})"),
(UPat(Ops.BIND, name="x"), lambda ctx,x: ctx[x.src[0]]),
@ -146,7 +146,7 @@ def pyrender(ast:UOp) -> str:
lst = list(ast.toposort())
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
not_rendered = {Ops.CONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.STACK,
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.FUNCTION, Ops.WHERE, Ops.END}

View file

@ -153,7 +153,7 @@ spec_tensor = PatternMatcher([
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.MULTI)), name="x"), lambda x: True),
# inputs to movement ops
(UPat((Ops.STACK, Ops.VCONST)), lambda: True),
(UPat(Ops.STACK), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.CDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True),
# movement ops
@ -198,6 +198,12 @@ spec_tensor = PatternMatcher([
# these ops can exist in programs but not the tensor spec. example: LOAD
spec_program = PatternMatcher([
# weakint is not allowed in programs
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
# Invalid is not allowed in program
(UPat(Ops.CONST, arg=Invalid), lambda: False),
# STACK/GEP in program. TODO: this should match Tensor
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),

View file

@ -22,6 +22,15 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
def convert(v:ConstType) -> ConstType: return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
def const_arg(u:UOp) -> ConstType|tuple[ConstType, ...]|None:
if u.op is Ops.CONST: return u.arg
if u.op is Ops.STACK and all(s.op is Ops.CONST for s in u.src): return tuple(s.arg for s in u.src)
return None
def fold_const_alu(a:UOp) -> UOp|None:
vals = [const_arg(s) for s in a.src]
return None if any(v is None for v in vals) else a.const_like(exec_alu(a.op, a.dtype, vals, False))
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
@ -71,6 +80,10 @@ propagate_invalid = PatternMatcher([
(UPat.var("a").where(UPat.var("b"), invalid_gate), lambda cond,i,x,a,b: (a|cond).where(a.where(b, x), i) if b.arg != Invalid else None),
(UPat(Ops.BITCAST, src=(invalid_pat,), name="bc"), lambda bc,i: i.cast(bc.dtype)),
(UPat(Ops.BITCAST, src=(invalid_gate,), name="bc"), lambda bc,cond,x,i: cond.where(x.bitcast(bc.dtype), i.bitcast(bc.dtype))),
# fold gated LOAD/STORE
(UPat(Ops.STORE, src=(UPat().index(invalid_pat).or_casted(), UPat())), lambda i: UOp(Ops.NOOP)),
(UPat(Ops.LOAD, src=(UPat().index(invalid_pat).or_casted(),), allow_any_len=True, name="x"),
lambda x,i: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
])
symbolic_simple = propagate_invalid + PatternMatcher([
@ -107,13 +120,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
# ** constant folding **
(UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
(UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
(UPat(GroupOp.Unary, src=(UPat((Ops.CONST, Ops.STACK)),), name="a"), fold_const_alu),
(UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.CONST, Ops.STACK)),)*2, name="a"), fold_const_alu),
(UPat(Ops.THREEFRY, src=(UPat.cvar("x"), UPat.cvar("key")), name="a"),
lambda a, x, key: a.const_like(threefry2x32(x, key).simplify().arg)),
(UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
(UPat(GroupOp.Ternary, src=(UPat((Ops.CONST, Ops.STACK)),)*3, name="a"), fold_const_alu),
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
@ -189,13 +200,12 @@ def gep_through_wmma(gep:UOp, wmma:UOp) -> UOp|None:
return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
gep_pushing = PatternMatcher([
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
# GEP/VECTORIZE, GEP/GEP, GEP/CONST
(UPat(Ops.GEP, name='g2').f(Ops.GEP, name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
(UPat(Ops.STACK, name='vec').f(Ops.GEP, name='gep'),
lambda gep, vec: UOp(Ops.STACK, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat.cvar("c", vec=False).f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.VCONST, name="c").f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# GEP on void is skipped
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
# GEP in order is removed
@ -216,7 +226,8 @@ gep_pushing = PatternMatcher([
commutative = PatternMatcher([
# ** COMMUTATIVE flipping (only for index) **
# NOTE: this can break merging vector math by only flipping some of them
(UPat(GroupOp.Commutative, dtype=dtypes.weakint, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
(UPat(GroupOp.Commutative, dtype=dtypes.weakint, name='x'), lambda x:
x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize and not x.src[0].tuplize < x.src[1].tuplize else None),
])
symbolic = symbolic_simple+commutative+PatternMatcher([
@ -311,7 +322,6 @@ def parse_valid(v:UOp) -> tuple[UOp, bool, int]|None:
if v.op is Ops.CMPLT and dtypes.is_int(v.src[0].dtype):
# X < c -> X <= c-1
return v.src[0], True, int((v.src[1]).vmax)-1
# NOTE: v.src[1].op can be Ops.VCONST
return None
def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
@ -446,9 +456,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
UPat.load(UPat(Ops.INDEX, name="index")))),
lambda index, gate, alt: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt)),
# fold gated LOAD/STORE
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())), lambda: UOp(Ops.NOOP)),
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
(UPat(Ops.STORE, src=(UPat(), invalid_pat)), lambda i: UOp(Ops.NOOP)),
# store of where with invalid -> gated store
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat))),

View file

@ -45,7 +45,7 @@ from tinygrad.uop.render import print_uops, pyrender
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, ProfileProgramEvent
from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
@ -118,7 +118,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)