mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
more_foldi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
20c45eb705 | ||
|
|
e2782cdf6e | ||
|
|
7e0a928004 |
2 changed files with 24 additions and 0 deletions
|
|
@ -655,6 +655,24 @@ class TestSymbolic(unittest.TestCase):
|
|||
with self.assertRaises(AssertionError):
|
||||
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
|
||||
|
||||
def test_div_mod_recombine_3level(self):
|
||||
gidx = Variable("gidx", 0, 150527)
|
||||
self.helper_test_variable(gidx//3%224*3 + gidx%3 + gidx//672*672, 0, 150527, "gidx")
|
||||
# different shapes
|
||||
x = Variable("x", 0, 5*7*11-1)
|
||||
self.helper_test_variable(x//11%7*11 + x%11 + x//77*77, 0, 5*7*11-1, "x")
|
||||
# result is x//a*c2 not just x
|
||||
x2 = Variable("x2", 0, 5*6*7-1)
|
||||
self.helper_test_variable(x2//7%6*14 + x2//42*84, 0, (5*6*7-1)//7*14, "(x2//7*14)")
|
||||
# negative variable range
|
||||
xn = Variable("x", -1000, 1000)
|
||||
self.helper_test_variable(xn//3%224*3 + xn%3 + xn//672*672, -1000, 1000, "x")
|
||||
self.helper_test_variable(xn//3%7*3 + xn//21*21, -999, 999, "(x//3*3)")
|
||||
# should NOT simplify: a*c1 != b (3*224 != 600)
|
||||
self.helper_test_variable(gidx//3%224*3 + gidx//600*600, 0, 150669, "(gidx//600*600+gidx//3%224*3)")
|
||||
# should NOT simplify: c1*c2 != c3 (224*3 != 700)
|
||||
self.helper_test_variable(gidx//3%224*3 + gidx//672*700, 0, 156769, "(gidx//672*700+gidx//3%224*3)")
|
||||
|
||||
def test_div_mod_recombine_with_gcd(self):
|
||||
b = Variable("b", 0, 100)
|
||||
exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"),
|
||||
lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above
|
||||
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
|
||||
lambda x,a,b,c1,c2,c3: x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
|
||||
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),
|
||||
|
|
@ -58,6 +60,10 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
|||
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3"),
|
||||
lambda y,x,c1,c2,c3: y+x*c2 if c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"))+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
|
||||
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
((UPat.var("y")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"))+(UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2"),
|
||||
lambda y,x,a,b,c1,c2,c3: y+x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
|
||||
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
||||
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue