fold BIND to CONST when min==max (#15568)

This commit is contained in:
b1tg 2026-04-01 23:19:04 +08:00 committed by GitHub
commit 20497f2840
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 1 deletions

View file

@ -1024,6 +1024,12 @@ class TestSymbolicVariables(unittest.TestCase):
assert (a * a).variables() == [a]
assert (a//4 + a//6).variables() == [a]
def test_variable_min_eq_max_bind_folds(self):
b = Variable("x", 1, 1).bind(1)
s = b.simplify()
self.assertEqual(s.op, Ops.CONST)
self.assertEqual(s.arg, 1)
class TestSymInfer(unittest.TestCase):
def test_sym_infer(self):
a = Variable("a", 0, 10)

View file

@ -236,7 +236,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU/variable min==max -> CONST
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.IDIV, Ops.MOD, Ops.DEFINE_VAR, Ops.SPECIAL}, name="x"),
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.IDIV, Ops.MOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding