more tests

This commit is contained in:
George Hotz 2026-02-20 17:55:56 +08:00
commit e2782cdf6e
2 changed files with 11 additions and 1 deletions

View file

@ -658,6 +658,16 @@ class TestSymbolic(unittest.TestCase):
def test_div_mod_recombine_3level(self):
gidx = Variable("gidx", 0, 150527)
self.helper_test_variable(gidx//3%224*3 + gidx%3 + gidx//672*672, 0, 150527, "gidx")
# different shapes
x = Variable("x", 0, 5*7*11-1)
self.helper_test_variable(x//11%7*11 + x%11 + x//77*77, 0, 5*7*11-1, "x")
# result is x//a*c2 not just x
x2 = Variable("x2", 0, 5*6*7-1)
self.helper_test_variable(x2//7%6*14 + x2//42*84, 0, (5*6*7-1)//7*14, "(x2//7*14)")
# should NOT simplify: a*c1 != b (3*224 != 600)
self.helper_test_variable(gidx//3%224*3 + gidx//600*600, 0, 150669, "(gidx//600*600+gidx//3%224*3)")
# should NOT simplify: c1*c2 != c3 (224*3 != 700)
self.helper_test_variable(gidx//3%224*3 + gidx//672*700, 0, 156769, "(gidx//672*700+gidx//3%224*3)")
def test_div_mod_recombine_with_gcd(self):
b = Variable("b", 0, 100)

View file

@ -51,7 +51,7 @@ symbolic_simple = propagate_invalid + PatternMatcher([
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c"),
lambda x,a,b,c: x//a if a.arg*c.arg==b.arg else None), # ((x//a)%c)+(x//a*c)*c = x//a. Note if a = 1 it degenerates to the one above
((UPat.var("x")//UPat.cvar("a"))%UPat.cvar("c1")*UPat.cvar("c2")+(UPat.var("x")//UPat.cvar("b"))*UPat.cvar("c3"),
lambda x,a,b,c1,c2,c3: x//a*c2 if c1.arg>0 and a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
lambda x,a,b,c1,c2,c3: x//a*c2 if a.arg*c1.arg==b.arg and c1.arg*c2.arg==c3.arg else None),
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
((UPat.var("y")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"))+UPat.var("x")%UPat.cvar("c"), lambda y,x,c: y+x),