symbolic fold double where (#9436)

* symbolic fold double where

a.where(b.where(c, d), d) -> (a & b).where(c, d). a pattern in optimizer

* test case
This commit is contained in:
chenyu 2025-04-05 05:12:17 -04:00 committed by GitHub
commit 407ca54382
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 19 additions and 0 deletions

View file

@ -555,6 +555,23 @@ class TestSymbolic(unittest.TestCase):
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
def test_where_merge_branches(self):
cond1 = Variable("s", 0, 10) < 6
cond2 = Variable("s", 0, 10) > 2
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
expr = cond1.where(cond2.where(a, b), b)
self.helper_test_variable(expr, 0, 3, "(a if ((s<6)&(2<s)) else b)")
def test_where_merge_branches2(self):
cond1 = Variable("s", 0, 10) < 5
cond2 = Variable("s", 0, 10) < 6
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
expr = cond1.where(cond2.where(a, b), b)
# (a if ((s<5)&(s<6)) else b) -> (a if (s<5) else b)
self.helper_test_variable(expr, 0, 3, "(a if (s<5) else b)")
def test_symbolic_div(self):
# from symbolic arange
a = Variable("a", 1, 10)

View file

@ -461,6 +461,8 @@ sym = symbolic_flat+PatternMatcher([
# ** where **
# push cast to branches
(UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
# ** pow **
((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
# ** load/store folding **