mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
symbolic fold double where (#9436)
* symbolic fold double where a.where(b.where(c, d), d) -> (a & b).where(c, d). a pattern in optimizer * test case
This commit is contained in:
parent
9c2fc695b5
commit
407ca54382
2 changed files with 19 additions and 0 deletions
|
|
@ -555,6 +555,23 @@ class TestSymbolic(unittest.TestCase):
|
|||
|
||||
self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
|
||||
|
||||
def test_where_merge_branches(self):
|
||||
cond1 = Variable("s", 0, 10) < 6
|
||||
cond2 = Variable("s", 0, 10) > 2
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
expr = cond1.where(cond2.where(a, b), b)
|
||||
self.helper_test_variable(expr, 0, 3, "(a if ((s<6)&(2<s)) else b)")
|
||||
|
||||
def test_where_merge_branches2(self):
|
||||
cond1 = Variable("s", 0, 10) < 5
|
||||
cond2 = Variable("s", 0, 10) < 6
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
expr = cond1.where(cond2.where(a, b), b)
|
||||
# (a if ((s<5)&(s<6)) else b) -> (a if (s<5) else b)
|
||||
self.helper_test_variable(expr, 0, 3, "(a if (s<5) else b)")
|
||||
|
||||
def test_symbolic_div(self):
|
||||
# from symbolic arange
|
||||
a = Variable("a", 1, 10)
|
||||
|
|
|
|||
|
|
@ -461,6 +461,8 @@ sym = symbolic_flat+PatternMatcher([
|
|||
# ** where **
|
||||
# push cast to branches
|
||||
(UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
|
||||
# a.where(b.where(c, d), d) -> (a & b).where(c, d)
|
||||
(UPat.var("a").where(UPat.var("b").where(UPat.var("c"), UPat.var("d")), UPat.var("d")), lambda a,b,c,d: (a&b).where(c,d)),
|
||||
# ** pow **
|
||||
((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
|
||||
# ** load/store folding **
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue