mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Remove Ops.VCONST (#16267)
* start removing vconst * remove a lot of vconst * const folding + strict ordering * update tests * spec from minigen * move that
This commit is contained in:
parent
7cdd9cbdeb
commit
55515747b7
15 changed files with 111 additions and 86 deletions
|
|
@ -127,10 +127,9 @@ class TestBitcastConstFolding(unittest.TestCase):
|
|||
|
||||
def test_vec_bitcast(self):
|
||||
with Context(SPEC=0):
|
||||
r = full_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src[0]
|
||||
self.assertEqual(r.op, Ops.STACK)
|
||||
self.assertEqual(r.dtype, dtypes.uint32.vec(3))
|
||||
self.assertEqual(tuple(x.arg for x in r.src), (2**32-1, 2**31, 75))
|
||||
srcs = full_rewrite(UOp.const(dtypes.int32.vec(3), (-1, -2**31, 75)).bitcast(dtypes.uint32.vec(3)).sink()).src
|
||||
self.assertTrue(all(r.op is Ops.CONST and r.dtype == dtypes.uint32 for r in srcs))
|
||||
self.assertEqual(tuple(x.arg for x in srcs), (2**32-1, 2**31, 75))
|
||||
|
||||
# folds advance indexing into basic indexing
|
||||
class TestIndexingConstFolding(unittest.TestCase):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,14 @@ from hypothesis import given, strategies as strat
|
|||
def apply_rewrite(expr):
|
||||
return full_rewrite(expr.sink()).src[0]
|
||||
|
||||
@Context(SPEC=0)
|
||||
def apply_rewrite_values(expr):
|
||||
srcs = full_rewrite(expr.sink()).src
|
||||
if len(srcs) == 1:
|
||||
if srcs[0].op is Ops.CONST: return (srcs[0].arg,)*srcs[0].dtype.count
|
||||
if srcs[0].op is Ops.STACK: return tuple(s.arg for s in srcs[0].src)
|
||||
return tuple(s.arg for s in srcs)
|
||||
|
||||
def evaluate_uop(uop, variables):
|
||||
if uop.op == Ops.CONST:
|
||||
return uop.arg
|
||||
|
|
@ -121,7 +129,7 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
|
|||
def test_graph_rewrite_div_folding_bug(self):
|
||||
lhs = UOp(Ops.ADD, dtypes.int.vec(4), src=(
|
||||
UOp(Ops.STACK, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg='lidx0', src=(UOp.const(dtypes.int, 32),)),)*4),
|
||||
UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
|
||||
UOp.const(dtypes.int.vec(4), (0, 256, 512, 768))))
|
||||
rhs = UOp.const(dtypes.int.vec(4), 2)
|
||||
unopt = lhs<rhs
|
||||
opt = apply_rewrite(unopt)
|
||||
|
|
@ -180,19 +188,17 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
|
|||
def test_gep_tuple_extraction(self):
|
||||
# GEP on a vector dtype to extract multiple elements as a vector
|
||||
base_vector = UOp.const(dtypes.float32.vec(4), (1.0, 2.0, 3.0, 4.0))
|
||||
optimized_uop = apply_rewrite(base_vector.gep((2, 3)))
|
||||
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [3.0, 4.0])
|
||||
self.assertEqual(list(apply_rewrite_values(base_vector.gep((2, 3)))), [3.0, 4.0])
|
||||
|
||||
def test_gep_on_vconst(self):
|
||||
# GEP on a VCONST to extract a single element
|
||||
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
|
||||
self.assertEqual(apply_rewrite(vconst.gep(2)).arg, 3.0)
|
||||
def test_gep_on_const_stack(self):
|
||||
# GEP on a const STACK to extract a single element
|
||||
const_stack = UOp.const(dtypes.float32.vec(4), (1.0, 2.0, 3.0, 4.0))
|
||||
self.assertEqual(apply_rewrite(const_stack.gep(2)).arg, 3.0)
|
||||
|
||||
def test_gep_tuple_on_vconst(self):
|
||||
# GEP on a VCONST using a tuple to extract multiple elements
|
||||
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
|
||||
optimized_uop = apply_rewrite(vconst.gep((1, 3)))
|
||||
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [8.0, 10.0])
|
||||
def test_gep_tuple_on_const_stack(self):
|
||||
# GEP on a const STACK using a tuple to extract multiple elements
|
||||
const_stack = UOp.const(dtypes.float32.vec(4), (7.0, 8.0, 9.0, 10.0))
|
||||
self.assertEqual(list(apply_rewrite_values(const_stack.gep((1, 3)))), [8.0, 10.0])
|
||||
|
||||
def test_gep_gep_simplification(self):
|
||||
# Nested GEP simplification on a vector dtype
|
||||
|
|
@ -204,8 +210,7 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
|
|||
# Vectorizing multiple elements using GEP
|
||||
base_vector = UOp.const(dtypes.float32.vec(4), (5.0, 10.0, 15.0, 20.0))
|
||||
vectorized_uop = UOp(Ops.STACK, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
|
||||
optimized_uop = apply_rewrite(vectorized_uop)
|
||||
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [5.0, 10.0, 15.0, 20.0])
|
||||
self.assertEqual(list(apply_rewrite_values(vectorized_uop)), [5.0, 10.0, 15.0, 20.0])
|
||||
|
||||
|
||||
import inspect
|
||||
|
|
|
|||
|
|
@ -14,6 +14,11 @@ simple_pm = PatternMatcher([
|
|||
((UPat.var('x') + UPat.cvar('c1')) + UPat.cvar('c2'), lambda x,c1,c2: x + (c1.arg+c2.arg)),
|
||||
])
|
||||
|
||||
def const_values(u:UOp):
|
||||
if u.op is Ops.CONST: return (u.arg,)*u.dtype.count
|
||||
if u.op is Ops.STACK: return tuple(x.arg for x in u.src)
|
||||
raise AssertionError(f"expected const-like UOp, got {u.op}")
|
||||
|
||||
class TestGraphRewriteConst(unittest.TestCase):
|
||||
def test_gep_const(self):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
|
|
@ -33,9 +38,9 @@ class TestGraphRewriteConst(unittest.TestCase):
|
|||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
|
||||
ret = graph_rewrite(v1+v2, sym)
|
||||
self.assertEqual(ret.op, Ops.VCONST)
|
||||
self.assertEqual(ret.op, Ops.STACK)
|
||||
self.assertEqual(ret.dtype, dtypes.int.vec(3))
|
||||
self.assertEqual(ret.arg, (5,7,9))
|
||||
self.assertEqual(const_values(ret), (5,7,9))
|
||||
|
||||
def test_add_const_lose_v(self):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
|
|
@ -587,58 +592,61 @@ class TestExpander(unittest.TestCase):
|
|||
def test_expand_add_broadcast(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+3)
|
||||
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 4
|
||||
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
|
||||
assert sink.op is Ops.UNROLL and len(const_values(sink.src[0])) == 4
|
||||
self.assertTupleEqual(const_values(sink.src[0]), (3,4,5,6))
|
||||
|
||||
def test_contract_simple(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
self.assertEqual(sink.op, Ops.VCONST)
|
||||
self.assertTupleEqual(sink.arg, (0,1,2,3))
|
||||
self.assertEqual(sink.op, Ops.STACK)
|
||||
self.assertTupleEqual(const_values(sink), (0,1,2,3))
|
||||
|
||||
def test_contract_axis_1(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is Ops.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
|
||||
vals = const_values(sink.src[0])
|
||||
assert sink.op is Ops.UNROLL and len(vals) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is Ops.STACK
|
||||
self.assertTupleEqual(vals[0:4], (0,4,8,12))
|
||||
self.assertTupleEqual(vals[12:], (3,7,11,15))
|
||||
|
||||
def test_contract_axis_2(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is Ops.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
|
||||
vals = const_values(sink.src[0])
|
||||
assert sink.op is Ops.UNROLL and len(vals) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is Ops.STACK
|
||||
self.assertTupleEqual(vals[0:4], (0,1,2,3))
|
||||
self.assertTupleEqual(vals[12:], (12,13,14,15))
|
||||
|
||||
def test_contract_axis_2_big(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (3, 2), (4, 2))
|
||||
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
|
||||
vals = const_values(sink.src[0])
|
||||
self.assertTupleEqual(vals[0:2], (0,4))
|
||||
self.assertTupleEqual(vals[12:14], (10,14))
|
||||
|
||||
def test_contract_multi_axis(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
|
||||
assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
|
||||
self.assertTupleEqual(const_values(sink.src[0])[0:4], (0, 4, 2, 6))
|
||||
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
|
||||
assert sink.op is Ops.UNROLL and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
|
||||
self.assertTupleEqual(const_values(sink.src[0])[0:4], (0, 2, 4, 6))
|
||||
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is Ops.UNROLL and sink.arg == ((1,2),(3,2))
|
||||
assert sink.src[0].op is Ops.VCONST and len(sink.src[0].arg) == 8
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
|
||||
assert sink.src[0].op is Ops.STACK and len(const_values(sink.src[0])) == 8
|
||||
self.assertTupleEqual(const_values(sink.src[0]), (0,2,1,3,4,6,5,7))
|
||||
|
||||
def test_contract_no_expand(self):
|
||||
e1 = UOp.variable("i", 0, 10, dtype=dtypes.int)
|
||||
|
|
@ -651,26 +659,28 @@ class TestExpander(unittest.TestCase):
|
|||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is Ops.VCONST and len(sink.arg) == 8
|
||||
assert sink.arg[0] == sink.arg[1]
|
||||
assert sink.arg[0] != sink.arg[2]
|
||||
assert sink.arg[6] == sink.arg[7]
|
||||
vals = const_values(sink)
|
||||
assert sink.op is Ops.STACK and len(vals) == 8
|
||||
assert vals[0] == vals[1]
|
||||
assert vals[0] != vals[2]
|
||||
assert vals[6] == vals[7]
|
||||
|
||||
def test_expand_same_axis(self):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+e2)
|
||||
self.assertEqual(sink.op, Ops.UNROLL)
|
||||
self.assertEqual(sink.src[0].op, Ops.VCONST)
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
|
||||
self.assertEqual(sink.src[0].op, Ops.STACK)
|
||||
self.assertTupleEqual(const_values(sink.src[0]), (0,5,10,15))
|
||||
|
||||
def test_expand_different_axis(self, flip=False):
|
||||
e1 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(Ops.UNROLL, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
|
||||
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
|
||||
assert sink.op is Ops.UNROLL and len(sink.src[0].arg) == 16
|
||||
vals = const_values(sink.src[0])
|
||||
assert sink.op is Ops.UNROLL and len(vals) == 16
|
||||
assert sink.arg == ((1, 4), (2, 4))
|
||||
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
|
||||
self.assertTupleEqual(vals, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
|
||||
|
||||
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
|
||||
|
||||
|
|
|
|||
|
|
@ -361,7 +361,7 @@ class TestUOpRender(unittest.TestCase):
|
|||
self.assertEqual(u.render(), "0")
|
||||
def test_render_vectorize_different_simplified(self):
|
||||
u = UOp(Ops.STACK, dtype=dtypes.int.vec(3), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)))
|
||||
self.assertEqual(u.render(), "(0, 1, 2)")
|
||||
self.assertEqual(u.render(), "{0,1,2}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class TestUPatCompile(unittest.TestCase):
|
|||
do_compile(up)
|
||||
|
||||
def test_const_folding(self):
|
||||
up = UPat(GroupOp.ALU-{Ops.THREEFRY}, name="a", src=UPat((Ops.VCONST, Ops.CONST)))
|
||||
up = UPat(GroupOp.ALU-{Ops.THREEFRY}, name="a", src=UPat((Ops.CONST, Ops.STACK)))
|
||||
do_compile(up)
|
||||
|
||||
@unittest.skip("fix this")
|
||||
|
|
|
|||
|
|
@ -273,7 +273,6 @@ pm_render = PatternMatcher([
|
|||
# for rendering, we use explicit VECTORIZE
|
||||
(UPat(Ops.CONST, name='c'),
|
||||
lambda c: UOp(Ops.STACK, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
|
||||
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.STACK, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.STACK, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
|
||||
(UPat(Ops.GEP, name='gep'), lambda gep: gep.src[0] if gep.src[0].dtype.vcount == 1 and gep.arg == (0,) else None),
|
||||
(UPat(Ops.STACK, src=(UPat(name='x'),)), lambda x: x),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,6 @@ pm_move_gates_from_index = PatternMatcher([
|
|||
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
|
||||
|
||||
# images use 2D INDEX now (y,x)
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat((Ops.CONST, Ops.VCONST, Ops.STACK), name="vec")), name="idx"),
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat((Ops.CONST, Ops.STACK), name="vec")), name="idx"),
|
||||
lambda idx,vec: idx.replace(src=(idx.src[0], vec.gep(1).cast(dtypes.int), vec.gep(0).cast(dtypes.int))) if vec.dtype.count == 2 else None),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ def reduce_collapse(red:UOp, u:UOp, pm:PatternMatcher=pm_reduce_collapse) -> UOp
|
|||
replaces: dict[UOp, UOp] = {}
|
||||
for u in included:
|
||||
for s in u.src:
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.VCONST, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: continue
|
||||
if s in included or s in replaces or s.op in {Ops.CONST, Ops.PARAM, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}: continue
|
||||
replaces[s] = UOp.variable(f'in{len(replaces)}', s.vmin, s.vmax, s.dtype)
|
||||
collapse_fxn = u.substitute(replaces).reduce(r, arg=Ops.ADD)
|
||||
sink = graph_rewrite(collapse_fxn, pm, name="reduce_collapse")
|
||||
|
|
|
|||
|
|
@ -78,8 +78,8 @@ class Ops(FastEnum):
|
|||
# control flow ops
|
||||
BARRIER = auto(); RANGE = auto(); IF = auto(); END = auto(); ENDIF = auto(); WAIT = auto()
|
||||
|
||||
# consts. VCONST is a vectorized const
|
||||
VCONST = auto(); CONST = auto()
|
||||
# const.
|
||||
CONST = auto()
|
||||
|
||||
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
|
||||
CUSTOM = auto(); CUSTOMI = auto()
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def fold_divmod_general(d: UOp) -> UOp|None:
|
|||
# nest_by_factor: x//c -> (x//f)//(c//f), x%c -> (x//f%(c//f))*f + b where b=x%f
|
||||
# FLOORDIV identity holds for any sign of x; FLOORMOD reconstruction needs x.vmin>=0
|
||||
results = []
|
||||
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
|
||||
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op is not Ops.CONST and 1 < abs(f) < c and (c%f)==0}:
|
||||
if (newxs := fold_divmod_general(x//div)) is not None:
|
||||
if d.op is Ops.FLOORDIV:
|
||||
results.append((len(newxs.backward_slice), newxs // (c // div)))
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
|
||||
# TODO: these should have the shape of the dtype.count
|
||||
case Ops.CONST | Ops.DEFINE_VAR: return ()
|
||||
case Ops.GEP | Ops.STACK | Ops.VCONST | Ops.VCAT | Ops.GETADDR: return ()
|
||||
case Ops.GEP | Ops.STACK | Ops.VCAT | Ops.GETADDR: return ()
|
||||
|
||||
# some ops init the shape
|
||||
case Ops.BIND | Ops.RANGE | Ops.SPECIAL | Ops.UNROLL: return ()
|
||||
|
|
@ -395,8 +395,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
# *** uop evaluation ***
|
||||
|
||||
def simplify(self, tracked=False):
|
||||
if self.op in {Ops.CONST, Ops.VCONST}: return self
|
||||
if self.op is Ops.SINK and all(s.op in {Ops.CONST, Ops.VCONST} or (s.op is Ops.STACK and len(s.src) == 0) for s in self.src): return self
|
||||
if self.op is Ops.CONST: return self
|
||||
if self.op is Ops.SINK and all(s.op is Ops.CONST or (s.op is Ops.STACK and len(s.src) == 0) for s in self.src): return self
|
||||
# late import!
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
|
||||
|
|
@ -482,7 +482,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if isinstance(i, int):
|
||||
# NOTE: these are just shortcuts to not have to create and fold later
|
||||
if self.op is Ops.STACK: return self.src[i]
|
||||
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
||||
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
||||
i = (i,)
|
||||
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
|
|
@ -513,10 +512,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b):
|
||||
assert len(b) > 0, "can't create const from empty tuple"
|
||||
b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype,
|
||||
arg=dtype.const(b),
|
||||
src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
|
||||
b = b[0] # doesn't have to be a STACK if they are all the same
|
||||
if isinstance(b, tuple):
|
||||
stk = [UOp(Ops.CONST, dtype.scalar(), arg=dtype.const(c), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ()) for c in b]
|
||||
ret = UOp.vectorize(*stk)
|
||||
else:
|
||||
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=(UOp(Ops.DEVICE, arg=device),) if device is not None else ())
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
|
||||
@staticmethod
|
||||
def unique_const(fill_value:ConstType, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, # type: ignore[override]
|
||||
|
|
@ -656,7 +657,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
def sgep(self, i:int) -> sint:
|
||||
match self.op:
|
||||
case Ops.CONST: return self.arg
|
||||
case Ops.VCONST: return self.arg[i]
|
||||
case Ops.STACK: return self.src[i].sintify()
|
||||
case _: raise RuntimeError(f"no sgep on {self.op}")
|
||||
|
||||
|
|
@ -850,14 +850,16 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
"""largest known int that divides self"""
|
||||
# TODO: for negatives it's not the largest
|
||||
if self.op is Ops.CONST: return self.arg
|
||||
if self.op is Ops.VCONST: return math.gcd(*self.arg)
|
||||
if self.op is Ops.STACK: return math.gcd(*[x.const_factor() for x in self.src])
|
||||
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
|
||||
return 1
|
||||
def divides(self, v:int) -> UOp|None:
|
||||
if v==1: return self
|
||||
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
|
||||
if self.op is Ops.STACK:
|
||||
srcs = tuple(s.divides(v) for s in self.src)
|
||||
return None if any(s is None for s in srcs) else UOp(Ops.STACK, self.dtype, cast(tuple[UOp, ...], srcs))
|
||||
if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
||||
if self.op is Ops.MUL:
|
||||
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
||||
|
|
@ -931,7 +933,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
||||
if self.op in {Ops.UNROLL, Ops.STACK}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
||||
if self.op is Ops.CONST and self.arg is not Invalid: return self.arg, self.arg
|
||||
if self.op is Ops.VCONST and Invalid not in self.arg: return (min(self.arg), max(self.arg))
|
||||
if self.op is Ops.GEP: return self.src[0]._min_max
|
||||
# TODO: CAST to bool/unsigned is not monotone, still some case can be simplified
|
||||
if self.op is Ops.CAST and self.dtype in dtypes.floats+dtypes.sints+(dtypes.weakint,):
|
||||
|
|
@ -1185,7 +1186,7 @@ class UPat(OpMixin):
|
|||
@staticmethod
|
||||
@functools.cache
|
||||
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True, arg=None):
|
||||
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name, arg=arg)
|
||||
return UPat(Ops.CONST, dtype, name=name, arg=arg)
|
||||
@staticmethod
|
||||
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
|
||||
|
|
@ -1572,8 +1573,7 @@ pm_lower_index_dtype = PatternMatcher([
|
|||
# There are no Unary ops at this point in symbolic, those are introduced later
|
||||
(UPat(GroupOp.Binary, name="u", src=(UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda u,x,y:
|
||||
x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)),
|
||||
(UPat((Ops.CONST, Ops.VCONST), dtype=dtypes.weakint, name="u"),
|
||||
lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
|
||||
(UPat(Ops.CONST, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=select_dtype(u)).cast(u.dtype) if u.arg!=Invalid else None),
|
||||
(UPat(Ops.WHERE, dtypes.weakint, src=(UPat.var("cond"), UPat.var("x").cast(dtypes.weakint), UPat.var("y").cast(dtypes.weakint))), lambda cond,x,y:
|
||||
cond.where(x.cast(dt:=least_upper_dtype(x.dtype, y.dtype)), y.cast(dt)).cast(dtypes.weakint)),
|
||||
(UPat(Ops.RANGE, src=(UPat.var("end").cast(dtypes.weakint)), name="r"), lambda r,end: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.weakint)),
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ renderer = PatternMatcher([
|
|||
(UPat(Ops.PARAM, src=(UPat(), UPat(), UPat(), UPat(), UPat(Ops.NOOP, name="x"))), lambda x: x.arg),
|
||||
(UPat((Ops.SPECIAL), name="x"), lambda x: x.arg),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: f"r{range_str(x)}"),
|
||||
(UPat((Ops.CONST, Ops.VCONST), name="x"), lambda x: str(x.arg)),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: str(x.arg)),
|
||||
(UPat(Ops.UNROLL, name="x"), lambda ctx,x,u: f"UNROLL({ctx[x.src[0]]}, {u.arg})"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({str(x.dtype)[7:]})({ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BIND, name="x"), lambda ctx,x: ctx[x.src[0]]),
|
||||
|
|
@ -146,7 +146,7 @@ def pyrender(ast:UOp) -> str:
|
|||
lst = list(ast.toposort())
|
||||
|
||||
cmap = consumer_map_from_toposort(lst)
|
||||
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
||||
not_rendered = {Ops.CONST, Ops.DEVICE}
|
||||
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.STACK,
|
||||
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.FUNCTION, Ops.WHERE, Ops.END}
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ spec_tensor = PatternMatcher([
|
|||
(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.MULTI)), name="x"), lambda x: True),
|
||||
|
||||
# inputs to movement ops
|
||||
(UPat((Ops.STACK, Ops.VCONST)), lambda: True),
|
||||
(UPat(Ops.STACK), lambda: True),
|
||||
(UPat({Ops.ADD, Ops.MUL, Ops.CDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True),
|
||||
|
||||
# movement ops
|
||||
|
|
@ -198,6 +198,12 @@ spec_tensor = PatternMatcher([
|
|||
|
||||
# these ops can exist in programs but not the tensor spec. example: LOAD
|
||||
spec_program = PatternMatcher([
|
||||
# weakint is not allowed in programs
|
||||
(UPat(GroupOp.All, dtypes.weakint), lambda: False),
|
||||
|
||||
# Invalid is not allowed in program
|
||||
(UPat(Ops.CONST, arg=Invalid), lambda: False),
|
||||
|
||||
# STACK/GEP in program. TODO: this should match Tensor
|
||||
(UPat(Ops.STACK, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.vcount and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat(Ops.GEP, src=(UPat.var("src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
|
|
|
|||
|
|
@ -22,6 +22,15 @@ def fold_bitcast(root:UOp, c:UOp) -> UOp|None:
|
|||
def convert(v:ConstType) -> ConstType: return struct.unpack(to_fmt, struct.pack(from_fmt, v))[0]
|
||||
return root.const_like(convert(c.arg) if root.dtype.count == 1 else tuple(map(convert, c.arg)))
|
||||
|
||||
def const_arg(u:UOp) -> ConstType|tuple[ConstType, ...]|None:
|
||||
if u.op is Ops.CONST: return u.arg
|
||||
if u.op is Ops.STACK and all(s.op is Ops.CONST for s in u.src): return tuple(s.arg for s in u.src)
|
||||
return None
|
||||
|
||||
def fold_const_alu(a:UOp) -> UOp|None:
|
||||
vals = [const_arg(s) for s in a.src]
|
||||
return None if any(v is None for v in vals) else a.const_like(exec_alu(a.op, a.dtype, vals, False))
|
||||
|
||||
invalid_pat = UPat(Ops.CONST, arg=Invalid, name="i")
|
||||
invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
|
||||
|
||||
|
|
@ -71,6 +80,10 @@ propagate_invalid = PatternMatcher([
|
|||
(UPat.var("a").where(UPat.var("b"), invalid_gate), lambda cond,i,x,a,b: (a|cond).where(a.where(b, x), i) if b.arg != Invalid else None),
|
||||
(UPat(Ops.BITCAST, src=(invalid_pat,), name="bc"), lambda bc,i: i.cast(bc.dtype)),
|
||||
(UPat(Ops.BITCAST, src=(invalid_gate,), name="bc"), lambda bc,cond,x,i: cond.where(x.bitcast(bc.dtype), i.bitcast(bc.dtype))),
|
||||
# fold gated LOAD/STORE
|
||||
(UPat(Ops.STORE, src=(UPat().index(invalid_pat).or_casted(), UPat())), lambda i: UOp(Ops.NOOP)),
|
||||
(UPat(Ops.LOAD, src=(UPat().index(invalid_pat).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x,i: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
|
||||
])
|
||||
|
||||
symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
|
|
@ -107,13 +120,11 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)) != UPat.var("x"),
|
||||
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
||||
# ** constant folding **
|
||||
(UPat(GroupOp.Unary, src=(UPat((Ops.VCONST, Ops.CONST)),), name="a"), lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg], False))),
|
||||
(UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.VCONST, Ops.CONST)),)*2, name="a"),
|
||||
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg], False))),
|
||||
(UPat(GroupOp.Unary, src=(UPat((Ops.CONST, Ops.STACK)),), name="a"), fold_const_alu),
|
||||
(UPat(GroupOp.Binary-{Ops.THREEFRY}, src=(UPat((Ops.CONST, Ops.STACK)),)*2, name="a"), fold_const_alu),
|
||||
(UPat(Ops.THREEFRY, src=(UPat.cvar("x"), UPat.cvar("key")), name="a"),
|
||||
lambda a, x, key: a.const_like(threefry2x32(x, key).simplify().arg)),
|
||||
(UPat(GroupOp.Ternary, src=(UPat((Ops.VCONST, Ops.CONST)),)*3, name="a"),
|
||||
lambda a: a.const_like(exec_alu(a.op, a.dtype, [a.src[0].arg, a.src[1].arg, a.src[2].arg], False))),
|
||||
(UPat(GroupOp.Ternary, src=(UPat((Ops.CONST, Ops.STACK)),)*3, name="a"), fold_const_alu),
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
|
|
@ -189,13 +200,12 @@ def gep_through_wmma(gep:UOp, wmma:UOp) -> UOp|None:
|
|||
return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
|
||||
|
||||
gep_pushing = PatternMatcher([
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST
|
||||
(UPat(Ops.GEP, name='g2').f(Ops.GEP, name='g1'),
|
||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(len(g1.arg))))),
|
||||
(UPat(Ops.STACK, name='vec').f(Ops.GEP, name='gep'),
|
||||
lambda gep, vec: UOp(Ops.STACK, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
|
||||
(UPat.cvar("c", vec=False).f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(c.arg)),
|
||||
(UPat(Ops.VCONST, name="c").f(Ops.GEP, name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
|
||||
# GEP on void is skipped
|
||||
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
|
||||
# GEP in order is removed
|
||||
|
|
@ -216,7 +226,8 @@ gep_pushing = PatternMatcher([
|
|||
commutative = PatternMatcher([
|
||||
# ** COMMUTATIVE flipping (only for index) **
|
||||
# NOTE: this can break merging vector math by only flipping some of them
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.weakint, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
(UPat(GroupOp.Commutative, dtype=dtypes.weakint, name='x'), lambda x:
|
||||
x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize and not x.src[0].tuplize < x.src[1].tuplize else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
|
|
@ -311,7 +322,6 @@ def parse_valid(v:UOp) -> tuple[UOp, bool, int]|None:
|
|||
if v.op is Ops.CMPLT and dtypes.is_int(v.src[0].dtype):
|
||||
# X < c -> X <= c-1
|
||||
return v.src[0], True, int((v.src[1]).vmax)-1
|
||||
# NOTE: v.src[1].op can be Ops.VCONST
|
||||
return None
|
||||
|
||||
def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
|
||||
|
|
@ -446,9 +456,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([
|
|||
UPat.load(UPat(Ops.INDEX, name="index")))),
|
||||
lambda index, gate, alt: UOp.store(index.src[0].index(gate.where(index.src[1], UOp.invalid())), alt)),
|
||||
# fold gated LOAD/STORE
|
||||
(UPat(Ops.STORE, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(), UPat())), lambda: UOp(Ops.NOOP)),
|
||||
(UPat(Ops.LOAD, src=(UPat().index(UPat.const(dtypes.weakint, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[1] if len(x.src) > 1 else x.const_like(0)), # invalid load produces 0, or the alt value if we have one
|
||||
(UPat(Ops.STORE, src=(UPat(), invalid_pat)), lambda i: UOp(Ops.NOOP)),
|
||||
# store of where with invalid -> gated store
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, name="index"), UPat.var("cond").where(UPat.var("val"), invalid_pat))),
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ from tinygrad.uop.render import print_uops, pyrender
|
|||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, ProfileProgramEvent
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
**{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.SHAPED_WMMA: "#FF5B5B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
|
|
@ -118,7 +118,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.weakint and u is not x: excluded.add(u)
|
||||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue