mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fold BIND to CONST when min==max (#15568)
This commit is contained in:
parent
9275f283e5
commit
20497f2840
2 changed files with 7 additions and 1 deletions
|
|
@ -1024,6 +1024,12 @@ class TestSymbolicVariables(unittest.TestCase):
|
|||
assert (a * a).variables() == [a]
|
||||
assert (a//4 + a//6).variables() == [a]
|
||||
|
||||
def test_variable_min_eq_max_bind_folds(self):
|
||||
b = Variable("x", 1, 1).bind(1)
|
||||
s = b.simplify()
|
||||
self.assertEqual(s.op, Ops.CONST)
|
||||
self.assertEqual(s.arg, 1)
|
||||
|
||||
class TestSymInfer(unittest.TestCase):
|
||||
def test_sym_infer(self):
|
||||
a = Variable("a", 0, 10)
|
||||
|
|
|
|||
|
|
@ -236,7 +236,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
|
||||
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
||||
# ALU/variable min==max -> CONST
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.IDIV, Ops.MOD, Ops.DEFINE_VAR, Ops.SPECIAL}, name="x"),
|
||||
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.IDIV, Ops.MOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
|
||||
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue