use FLOORDIV and FLOORMOD (#16048)

* use FLOORDIV and FLOORMOD

also removed CORRECT_DIVMOD_FOLDING

* fix

* Revert "fix"

This reverts commit 86af33b88ef31943c61e67189b072eca4896409a.

* fix

* fix
This commit is contained in:
chenyu 2026-05-05 18:32:54 -04:00 committed by GitHub
commit 34fe37d64e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 252 additions and 223 deletions

View file

@ -417,7 +417,7 @@ jobs:
llvm: 'true'
- name: Test openpilot model kernel count and gate usage
run: |
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1486 ALLOWED_GATED_READ_IMAGE=17 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1486 ALLOWED_GATED_READ_IMAGE=18 FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp16
run: FLOAT16=1 DEV=CL IMAGE=1 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
- name: Test openpilot CL compile fp32 (test correctness)

View file

@ -7,7 +7,7 @@ import z3
from tinygrad import Variable, dtypes
from tinygrad.uop.ops import UOp
from tinygrad.uop.validate import uops_to_z3
from tinygrad.helpers import DEBUG, Context
from tinygrad.helpers import DEBUG
seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100)
print(f"Seed: {seed}", flush=True)
@ -56,8 +56,7 @@ if __name__ == "__main__":
v = [u1,u2,u3]
expr = random_int_expr(6)
with Context(CORRECT_DIVMOD_FOLDING=1):
simplified_expr = expr.simplify()
simplified_expr = expr.simplify()
solver = z3.Solver(ctx=z3.Context())
solver.set(timeout=5000) # some expressions take very long to verify, but it's very unlikely they actually return sat
@ -74,10 +73,9 @@ if __name__ == "__main__":
m = solver.model()
n1, n2, n3 = m[v1], m[v2], m[v3]
u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long())
with Context(CORRECT_DIVMOD_FOLDING=1):
num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
if num==rn: print("z3 found a mismatch but the expressions are equal!!")
num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
if num==rn: print("z3 found a mismatch but the expressions are equal!!")
assert False, f"mismatched {expr.render()} at v1={m[v1]}; v2={m[v2]}; v3={m[v3]} = {num} != {rn}\n" +\
"Reproduce with:\n" +\
f"v1=Variable(\"{u1.arg[0]}\", {u1.arg[1]}, {u1.arg[2]})\n" +\

View file

@ -2,7 +2,7 @@ import random, sys
import z3
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.validate import uops_to_z3
from tinygrad.helpers import DEBUG, Context, colored
from tinygrad.helpers import DEBUG, colored
seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100)
print(f"Seed: {seed}", flush=True)
@ -36,8 +36,7 @@ if __name__ == "__main__":
variable_names += [f"r{i}" for i in range(num_ranges)]
expr = get_random_expr(ranges, factors)
with Context(CORRECT_DIVMOD_FOLDING=1):
simplified_expr = expr.simplify()
simplified_expr = expr.simplify()
if DEBUG>=1:
print(expr.render(simplify=False), " --> ", simplified_expr.render(simplify=False))

View file

@ -1,12 +1,18 @@
import unittest, itertools
from tinygrad.codegen.late.devectorizer import load_store_indexing
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.symbolic import simplify_valid
from tinygrad.uop.ops import UOp, Ops, graph_rewrite
from tinygrad.uop.symbolic import simplify_valid, sym, pm_move_where_on_load
from tinygrad.helpers import Context
from test.helpers import full_rewrite
from test.null.test_uop_symbolic import check_uop_against_string
# symbolic-only idx + valid simplification (no late lowering of FLOORDIV/FLOORMOD)
def simplify_valid_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move_where_on_load, name="simplify_valid_idx")
# image-aware idx + valid simplification: adds the codegen-layer matcher that drops provably in-bounds gates
def simplify_image_idx(sink: UOp) -> UOp: return graph_rewrite(sink, sym+pm_move_where_on_load+load_store_indexing, name="simplify_image_idx")
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(Ops.LOAD, dtypes.float, (
UOp(Ops.PARAM, dtypes.float.ptr(), arg=0).index(idx.valid(valid), ptr=True),
@ -47,11 +53,10 @@ class TestHelpers(unittest.TestCase):
class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid, extra=()):
with Context(NOOPT=1, SPEC=0):
load = full_rewrite(UOp.sink(load, *extra)).src[0]
idx, valid = load.src[0].src[1], load.src[0].src[2]
check_uop_against_string(self, idx, sidx)
check_uop_against_string(self, valid, svalid)
load = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
off = load.src[0].src[1]
check_uop_against_string(self, off.get_idx(), sidx)
check_uop_against_string(self, off.get_valid(), svalid)
def test_cumsum(self):
gidx0 = Special("gidx0", 5)
@ -216,18 +221,18 @@ class TestValidIdxSimplification(unittest.TestCase):
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
with Context(NOOPT=1, SPEC=0):
load = full_rewrite(load.sink()).src[0]
idx = load.src[0].src[1]
load = simplify_image_idx(load.sink()).src[0]
off = load.src[0].src[1]
idx = off.get_idx()
self.assertEqual(idx.op, Ops.STACK)
self.assertEqual(len(idx.src), 2)
idx0, idx1 = idx.src[0], idx.src[1]
check_uop_against_string(self, idx0, sidx0)
check_uop_against_string(self, idx1, sidx1)
if svalid is not None:
check_uop_against_string(self, load.src[0].src[2], svalid)
check_uop_against_string(self, off.get_valid(), svalid)
else:
self.assertEqual(len(load.src[0].src), 2, "svalid is None but load still has a valid")
self.assertEqual(off.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
def test_idx_gt_c(self):
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
@ -447,12 +452,12 @@ class TestImageSimplification(unittest.TestCase):
load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1))
self.check(load, None, "(lidx1*128+gidx0//2+144)", "(lidx0*2+r0+-3)")
# TODO: this is the same idx as above, but simplifying idx too early makes it hard to drop the valid
# same idx, written without the inline simplification of the inner div/mod
alu0 = ((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)
alu1 = (lidx0*2+r0+-3)
valid = ((lidx1<7)&((((lidx0*2+r0)<3)!=1)&((lidx0*2+r0)<35)))
load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1))
self.check(load, "(lidx1<7)", "((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)", "(lidx0*2+r0+-3)")
self.check(load, None, "(lidx1*128+gidx0//2+144)", "(lidx0*2+r0+-3)")
def test_simplify8(self):
# from openpilot compile3, kernel r_4_16_8_16_4_4_3_3n1

View file

@ -1,16 +1,8 @@
import unittest
from tinygrad import Variable
from tinygrad.helpers import Context
class TestFuzzFailure(unittest.TestCase):
def setUp(self):
self.context = Context(CORRECT_DIVMOD_FOLDING=1)
self.context.__enter__()
def tearDown(self):
self.context.__exit__(None, None, None)
def test_fuzz_failure1(self):
v1=Variable('v1', 0, 8)
v2=Variable('v2', 0, 2)

View file

@ -3,7 +3,6 @@ import unittest, pickle, functools, math
import z3
from tinygrad.dtype import dtypes, ConstType, DType, Invalid
from tinygrad.helpers import Context
from test.helpers import get_uops
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid, pm_move_where_on_load
@ -181,8 +180,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 8)*1, 0, 8, "a")
def test_mul_neg_1(self):
self.helper_test_variable((Variable("a", 0, 2)*-1)//3, 0, 0, "0")
self.helper_test_variable((Variable("a", 2, 7)*-1)//3, -2, 0, "((a//3)*-1)")
self.helper_test_variable((Variable("a", 0, 2)*-1)//3, -1, 0, "((a*-1)//3)")
self.helper_test_variable((Variable("a", 2, 7)*-1)//3, -3, -1, "((a*-1)//3)")
def test_mul_2(self):
self.helper_test_variable(Variable("a", 0, 8)*2, 0, 16, "(a*2)")
@ -203,8 +202,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0")
def test_div_neg_min_max(self):
self.helper_test_variable(Variable("a", 1, 7) // -2, -3, 0, "((a//2)*-1)")
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((a//2)*-1)")
self.helper_test_variable(Variable("a", 1, 7) // -2, -4, -1, "(a//-2)")
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "(a//-2)")
def test_div_mod_zero(self):
with self.assertRaises(ZeroDivisionError):
@ -238,14 +237,14 @@ class TestSymbolic(unittest.TestCase):
def test_mod_min_max(self):
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", 1, 10), 0, 9, "(x%y)")
self.helper_test_variable(Variable("x", -10, 0)%Variable("y", 1, 10), -9, 0, "(((x*-1)%y)*-1)")
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", -10, -1), 0, 9, "(x%(y*-1))")
self.helper_test_variable(Variable("x", -10, 0)%Variable("y", -10, -1), -9, 0, "(((x*-1)%(y*-1))*-1)")
self.helper_test_variable(Variable("x", -10, 10)%Variable("y", -10, -1), -9, 9, "(x%(y*-1))")
self.helper_test_variable(Variable("x", -10, 0)%Variable("y", 1, 10), 0, 9, "(x%y)")
self.helper_test_variable(Variable("x", 0, 10)%Variable("y", -10, -1), -9, 0, "(x%y)")
self.helper_test_variable(Variable("x", -10, 0)%Variable("y", -10, -1), -9, 0, "(x%y)")
self.helper_test_variable(Variable("x", -10, 10)%Variable("y", -10, -1), -9, 0, "(x%y)")
# test _min_max directly without the rewrite taking out the sign
# test _min_max directly: floor mod with positive divisor is in [0, c-1]; with negative divisor in [c+1, 0]
self.assertEqual((Variable("x", -10, 0)%Variable("y", -10, -1))._min_max, (-9, 0))
self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (-9, 0))
self.assertEqual((Variable("x", -10, 0)%Variable("y", 1, 10))._min_max, (0, 9))
def test_range_div_its_symbolic_bound(self):
a = Variable("a", 1, 10, dtypes.weakint)
@ -262,12 +261,12 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 6) // 2, 0, 3, "(a//2)")
self.helper_test_variable(Variable("x", 0, 10)//Variable("y", 1, 10), 0, 10, "(x//y)")
self.helper_test_variable(Variable("x", -10, 0)//Variable("y", 1, 10), -10, 0, "(((x*-1)//y)*-1)")
self.helper_test_variable(Variable("x", 0, 10)//Variable("y", -10, -1), -10, 0, "((x//(y*-1))*-1)")
self.helper_test_variable(Variable("x", -10, 0)//Variable("y", -10, -1), 0, 10, "((x*-1)//(y*-1))")
self.helper_test_variable(Variable("x", -10, 0)//Variable("y", 1, 10), -10, 0, "(x//y)")
self.helper_test_variable(Variable("x", 0, 10)//Variable("y", -10, -1), -10, 0, "(x//y)")
self.helper_test_variable(Variable("x", -10, 0)//Variable("y", -10, -1), 0, 10, "(x//y)")
self.helper_test_variable(Variable("x", -10, 10)//Variable("y", 1, 10), -10, 10, "(x//y)")
self.helper_test_variable(Variable("x", -10, 10)//Variable("y", -10, -1), -10, 10, "((x//(y*-1))*-1)")
self.helper_test_variable(Variable("x", -10, 10)//Variable("y", -10, -1), -10, 10, "(x//y)")
def test_mod_factor(self):
self.helper_test_variable(usum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b%2)*50)")
@ -334,12 +333,12 @@ class TestSymbolic(unittest.TestCase):
def test_mod_mod_wrong_sign(self):
v1=Variable("v1", 0, 128)
v3=Variable("v3", 0, 7)
self.helper_test_variable((((((v1%2)*2)+((v3+-1)%5))+-2)%5), -3, 4, "(v1%2*2+(v3+-1)%5+-2)")
self.helper_test_variable((((((v1%2)*2)+((v3+-1)%5))+-2)%5), 0, 4, "((v3+v1%2*2+-3)%5)")
def test_mod_mod_wrong_sign2(self):
v2=Variable("v2", 0, 8)
v3=Variable("v3", 0, 4)
self.helper_test_variable((((((v3+3)%7)+(v2+-2))%7)%7), -2, 6, "(((v2+((v3+3)%7))+-2)%7)")
self.helper_test_variable((((((v3+3)%7)+(v2+-2))%7)%7), 0, 6, "((v2+v3+1)%7)")
def test_mul_mul(self):
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
@ -357,21 +356,21 @@ class TestSymbolic(unittest.TestCase):
def test_div_const_div(self):
a = Variable("a", 0, 124)
self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)")
self.helper_test_variable(((-a)//2-1)//2, -31, 0, "(((a+2)//4)*-1)")
self.helper_test_variable(((-a)//2+10)//2, -26, 5, "((((a//2)*-1)+10)//2)")
self.helper_test_variable(((-a)//2-1)//2, -32, -1, "((a*-1+2)//4+-1)")
self.helper_test_variable(((-a)//2+10)//2, -26, 5, "(a*-1//4+5)")
def test_div_const_div_wrong_sign(self):
a = Variable("a", 0, 124)
self.helper_test_variable(((a-10)//2+10)//2, 2, 33, "((((a+-10)//2)+10)//2)")
self.helper_test_variable(((a-10)//2+10)//2, 2, 33, "((a+2)//4+2)")
def test_div_const_div_wrong_sign_divisor(self):
a = Variable("a", 0, 124)
self.helper_test_variable(((a+10)//-2+10)//-4, -1, 14, "(((((a//2)*-1)+5)//4)*-1)")
self.helper_test_variable(((a+10)//-2+10)//-4, -2, 14, "(((a+10)//-2+10)//-4)")
def test_neg_mod(self):
a = Variable("a", 0, 124)
self.helper_test_variable((-a)%4, -3, 0, "((a%4)*-1)")
self.helper_test_variable(a%-4, 0, 3, "(a%4)")
self.helper_test_variable((-a)%4, 0, 3, "(a*-1%4)")
self.helper_test_variable(a%-4, -3, 0, "(a%-4)")
def test_distribute_mul(self):
self.helper_test_variable(usum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
@ -387,11 +386,11 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
def test_big_mod(self):
self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(((a*-1)%10)*-1)")
self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
self.helper_test_variable(Variable("a", -20, 20)%10, 0, 9, "(a%10)")
self.helper_test_variable(Variable("a", -20, 0)%10, 0, 9, "(a%10)")
self.helper_test_variable(Variable("a", -20, 1)%10, 0, 9, "(a%10)")
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
self.helper_test_variable(Variable("a", -1, 20)%10, 0, 9, "(a%10)")
def test_ge_remove(self):
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "False")
@ -439,8 +438,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(c & c.logical_not(), False, False, "False")
def test_mod_factor_negative(self):
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -27, 27, "(((a+(b*28))+-29)%28)")
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -27, 27, "(((a+(b*28))+-29)%28)")
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+b*28+-29)%28)")
self.helper_test_variable(usum([uconst(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((a+b*28+-29)%28)")
def test_sum_combine_num(self):
self.helper_test_variable(usum([uconst(29), Variable("a", 0, 10), uconst(-23)]), 6, 16, "(a+6)")
@ -448,22 +447,12 @@ class TestSymbolic(unittest.TestCase):
def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
@unittest.expectedFailure # only correct for floordiv, not truncdiv
def test_div_cancel(self):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
def test_div_cancel_correct(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)")
@unittest.expectedFailure # only correct for floordiv, not truncdiv
def test_mod_cancel(self):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")
def test_mod_cancel_correct(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)")
def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
@ -475,22 +464,22 @@ class TestSymbolic(unittest.TestCase):
lidx1 = UOp.variable("lidx1", 0, 1)
ridx1005 = UOp.variable("ridx1005", 0, 2)
ridx1006 = UOp.variable("ridx1006", 0, 2)
self.helper_test_variable((lidx1+((gidx1*18)+(ridx1005*18)+(lidx0*162))+(gidx0*2)+(ridx1006*2)+-40)//18, -2, 20,
"(((((lidx1+(((gidx1*18)+(ridx1005*18))+(lidx0*162)))+(gidx0*2))+(ridx1006*2))+-40)//18)")
self.helper_test_variable((lidx1+((gidx1*18)+(ridx1005*18)+(lidx0*162))+(gidx0*2)+(ridx1006*2)+-40)//18, -3, 20,
"(gidx1+ridx1005+lidx0*9+(gidx0+ridx1006+7)//9+-3)")
def test_add_div(self):
# careful about the lower bounds and upper bounds
self.helper_test_variable((Variable("a", 0, 5)-2)//4, 0, 0, "0")
self.helper_test_variable((Variable("a", 0, 5)-1)//4, 0, 1, "((a+-1)//4)")
self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "((a+2)//4+-1)")
self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "((a+3)//4+-1)")
self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)")
self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((a+1)//4)")
self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((a+2)//4)")
self.helper_test_variable((Variable("a", 0, 5)+3)//4, 0, 2, "((a+3)//4)")
self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "((a//4)+1)")
self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((a+1)//4)+1)")
self.helper_test_variable((Variable("a", 0, 5)+4)//4, 1, 2, "(a//4+1)")
self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "((a+1)//4+1)")
def test_div_neg_rem(self):
self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "((((a+1)//2)*-1)+128)")
self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "(a*-1//2+128)")
def test_mul_div_factor_mul(self):
self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")
@ -502,7 +491,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
def test_mul_div_factor_div_neg(self):
self.helper_test_variable((Variable("a", 0, 10)*-4+4)//8, -4, 0, "(((a*-1)+1)//2)")
self.helper_test_variable((Variable("a", 0, 10)*-4+4)//8, -5, 0, "((a*-1+1)//2)")
def test_div_symbolic_const_gcd(self):
a = Variable("a", -10, 10)
@ -520,8 +509,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((d1*a*d2*b*d1)//(d1*d2), -1000, 1000, "(a*(b*d1))", test_z3=False)
self.helper_test_variable((d1*a + b*d1)//(d1), -20, 20, "(a+b)", test_z3=False)
self.helper_test_variable((d1*a + b*d1 + c*d1)//(d1), -30, 30, "(c+(a+b))", test_z3=False)
self.helper_test_variable((3*a*d1 + 9*b*d1)//(3*d1*d2), -40, 40, "(((a+(b*3))//(d2*-1))*-1)", test_z3=False)
self.helper_test_variable((3*a*d1 + 9*b*d1+3)//(3*d1*d2), -401, 399, "(((((a*d1)+((b*d1)*3))+1)//((d1*d2)*-1))*-1)", test_z3=False)
self.helper_test_variable((3*a*d1 + 9*b*d1)//(3*d1*d2), -40, 40, "((a+b*3)//d2)", test_z3=False)
self.helper_test_variable((3*a*d1 + 9*b*d1+3)//(3*d1*d2), -401, 399, "((a*d1+b*d1*3+1)//(d1*d2))", test_z3=False)
def test_symbolic_factor_remainder_div(self):
a = Variable("a", 0, 10)
@ -532,7 +521,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((d*a*20+b*d*5+10)//(5*d), 0, 52, "((b+(a*4))+(2//d))")
def test_mod_gcd_factor_neg(self):
self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, -4, 4, "((((a*-1)+1)%2)*4)")
self.helper_test_variable((Variable("a", 0, 10)*-4+4)%8, 0, 4, "((a*-1+1)%2*4)")
def test_mod_gcd_fold_neg(self):
self.helper_test_variable((Variable("a", 0, 10)*-8+20)%4, 0, 0, "0")
@ -540,22 +529,21 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_partial_remove(self):
self.helper_test_variable(usum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
def test_cdiv_const_evaluation(self):
self.helper_test_variable((Variable("a", 0, 2)-12)//8, -1, -1, "-1")
self.helper_test_variable((-Variable("a", 0, 2))//7, 0, 0, "0")
def test_floordiv_const_evaluation(self):
self.helper_test_variable((Variable("a", 0, 2)-12)//8, -2, -2, "-2")
self.helper_test_variable((-Variable("a", 0, 2))//7, -1, 0, "(a*-1//7)")
def test_cmod_const_evaluation(self):
self.helper_test_variable((Variable("a", 1, 1)*-3)%8, -3, -3, "-3")
self.helper_test_variable((-Variable("a", 10, 10))%7, -3, -3, "-3")
def test_floormod_const_evaluation(self):
self.helper_test_variable((Variable("a", 1, 1)*-3)%8, 5, 5, "5")
self.helper_test_variable((-Variable("a", 10, 10))%7, 4, 4, "4")
def test_div_numerator_negative(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "(idx*-1)")
def test_nest_div_negative_factor(self):
ridx0=Variable("ridx0", 0, 9)
ridx1=Variable("ridx1", 0, 6)
self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "(((ridx0//5)*-1)+1)")
self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "((ridx1+ridx0*-7+28)//35+1)")
def test_div_into_mod(self):
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
@ -568,11 +556,11 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(x%12//4*4 + x%4 + x//12*12, 0, 23, "x")
def test_div_neg_cancel(self):
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((idx//4)+1)")
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((idx+3)//4)")
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((idx+2)//4)")
self.helper_test_variable((-Variable("idx", 0, 100))//2, -50, 0, "((idx//2)*-1)")
self.helper_test_variable(Variable("idx", 0, 100)//-2, -50, 0, "((idx//2)*-1)")
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 0, 25, "((idx*-1+199)//-4+50)")
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((idx*-1+200)//-4+50)")
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, -1, 24, "((idx*-1+201)//-4+50)")
self.helper_test_variable((-Variable("idx", 0, 100))//2, -50, 0, "(idx*-1//2)")
self.helper_test_variable(Variable("idx", 0, 100)//-2, -50, 0, "(idx//-2)")
def test_sum_div_big_const(self):
gidx0 = Variable("gidx0", 0, 24)
@ -647,22 +635,22 @@ class TestSymbolic(unittest.TestCase):
def test_div_neg_all_range(self):
gidx = Variable("gidx", 0, 124)
lidx = Variable("lidx", 0, 7)
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((gidx*2)+(lidx//4))+1)")
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "((gidx*2)+((lidx+3)//4))")
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "((gidx*2)+((lidx+2)//4))")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "((gidx*2)+((lidx+1)//4))")
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 0, 250, "((gidx*-8+lidx*-1+999)//-4+250)")
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 249, "((gidx*-8+lidx*-1+1000)//-4+250)")
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, -1, 249, "((gidx*-8+lidx*-1+1001)//-4+250)")
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, -1, 249, "((gidx*-8+lidx*-1+1002)//-4+250)")
def test_div_neg_then_neg(self):
# taken from arange opts
lidx0 = Variable("lidx0", 0, 7)
lidx1 = Variable("lidx1", 0, 7)
alu2 = -lidx0-lidx1
self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4")
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((lidx0+lidx1)+25)//32)")
self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")
self.helper_test_variable((((alu2+14)//(-32))+4), 3, 4, "((lidx0*-1+lidx1*-1+14)//-32+4)")
self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -3, "((lidx0*-1+lidx1*-1+14)//-32*-1+-4)")
self.helper_test_variable((((alu2+134)//(-32))+4), -1, 0, "((lidx0*-1+lidx1*-1+134)//-32+4)")
self.helper_test_variable((((alu2+142)//(-32))+4), -1, 0, "((lidx0*-1+lidx1*-1+142)//-32+4)")
self.helper_test_variable((((alu2+150)//(-32))+4), -1, -1, "-1")
self.helper_test_variable((((alu2+158)//(-32))+4), -1, -1, "-1")
def test_div_mod_recombine(self):
gidx = Variable("gidx", 0, 124)
@ -696,7 +684,7 @@ class TestSymbolic(unittest.TestCase):
# negative variable range
xn = Variable("x", -1000, 1000)
self.helper_test_variable(xn//3%224*3 + xn%3 + xn//672*672, -1000, 1000, "x")
self.helper_test_variable(xn//3%7*3 + xn//21*21, -999, 999, "(x//3*3)")
self.helper_test_variable(xn//3%7*3 + xn//21*21, -1002, 999, "(x//3*3)")
# should NOT simplify: a*c1 != b (3*224 != 600)
self.helper_test_variable(gidx//3%224*3 + gidx//600*600, 0, 150669, "(gidx//600*600+gidx//3%224*3)")
# should NOT simplify: c1*c2 != c3 (224*3 != 700)
@ -709,7 +697,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
def test_div_partial_quotient(self):
# IDIV should extract partial quotients when const_factor > divisor, matching what MOD already does
# FLOORDIV should extract partial quotients when const_factor > divisor, matching what FLOORMOD already does
# (f*x+c)//d -> (f%d*x+c)//d + (f//d)*x when f >= d
b = Variable("b", 0, 100)
self.helper_test_variable((31*b+1)//18, 0, 172, "(((b*13)+1)//18+b)")
@ -730,8 +718,7 @@ class TestSymbolic(unittest.TestCase):
def test_div_by_factor_tie_break(self):
a = Variable("a", 0, 1)
b = Variable("b", 0, 1)
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)")
self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)")
def test_div_mod_recombine_large_coeff(self):
# recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way
@ -741,7 +728,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((25*a+3)%10 + ((25*a+3)//10)*10, 3, 253, "((a*25)+3)")
def test_mod_nest_by_factor(self):
# (a*f+b) % (f*k) = (a%k)*f + b when 0<=b<f — mirrors nest_div_by_factor for MOD
# (a*f+b) % (f*k) = (a%k)*f + b when 0<=b<f — mirrors nest_div_by_factor for FLOORMOD
gidx0 = Variable("gidx0", 0, 15)
lidx0 = Variable("lidx0", 0, 3)
# f=4, k=2, c=8: (gidx0*4+lidx0)%8 = (gidx0%2)*4 + lidx0
@ -755,7 +742,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a*3+b)%9, 0, 8, "(b+a%3*3)")
def test_mod_nest_by_factor_with_const(self):
# nest_by_factor MOD with non-zero constant offset: (a*f+b+const) % (f*k) = (a%k)*f + b + const when 0<=b+const<f
# nest_by_factor FLOORMOD with non-zero constant offset: (a*f+b+const) % (f*k) = (a%k)*f + b + const when 0<=b+const<f
a = Variable("a", 0, 7)
b = Variable("b", 0, 1)
# f=4, k=2, const=2: (a*4+b+2)%8 = (a%2)*4 + b + 2
@ -767,7 +754,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a*3+b+1)%6, 1, 5, "(b+a%2*3+1)")
def test_div_nest_by_factor_with_const(self):
# nest_by_factor IDIV: (160*a + 5*b + 4*c + K) // 60 should pick div=5 (clean) over div=4 (dirty)
# nest_by_factor FLOORDIV: (160*a + 5*b + 4*c + K) // 60 should pick div=5 (clean) over div=4 (dirty)
a = Variable("a", 0, 2)
b = Variable("b", 0, 31)
c = Variable("c", 0, 1)
@ -827,12 +814,26 @@ class TestSymbolic(unittest.TestCase):
# TODO: simplify the true branch
self.helper_test_variable((idx<4).where(idx//4, idx.const_like(-1)), -1, 6, "(idx<4).where((idx//4), -1)")
def test_idiv_lt(self):
def test_floordiv_lt(self):
# x//d<c <=> x<c*d for d>0
idx = Variable("idx", 0, 24)
self.helper_test_variable((idx//4<3), 0, 1, "(idx<12)")
self.helper_test_variable(((idx-20)//4<-3), 0, 1, "(idx<5)")
self.helper_test_variable(((idx-10)//4<0), 0, 1, "(idx<7)")
self.helper_test_variable((idx//-4<-3), 0, 1, "(((idx//4)*-1)<-3)")
self.helper_test_variable(((idx-20)//4<-3), 0, 1, "(idx<8)")
self.helper_test_variable(((idx-10)//4<0), 0, 1, "(idx<10)")
self.helper_test_variable((idx//-4<-3), 0, 1, "((idx//-4)<-3)")
def test_nested_div_mod_negative_inner_divisor(self):
# (x % (k*c)) // c -> (x // c) % k requires k>0; (x % (k*c)) % c -> x % c is unconditional for c>0
a = Variable("a", 0, 100)
self.helper_test_variable((a % -8) // 2, -4, 0, "(a%-8//2)")
self.helper_test_variable((a % -8) % 2, 0, 1, "(a%2)")
def test_floordiv_lt_negative_c(self):
# x//d<c with negative c also reduces to x<c*d for d>0
idx = Variable("idx", -20, 20)
self.helper_test_variable((idx//4 < 0), 0, 1, "(idx<0)")
self.helper_test_variable((idx//4 < -1), 0, 1, "(idx<-4)")
self.helper_test_variable((idx//4 < -2), 0, 1, "(idx<-8)")
def test_simplex_lt(self):
a = Variable("a", 0, 3)
@ -981,10 +982,10 @@ class TestSymbolic(unittest.TestCase):
self.assertIn((a.cast(dtypes.long)*b.cast(dtypes.long)).render(), "(long)((a*b))")
def test_nested_mod_negative_range(self):
# (x%(k*c))%c = x%c holds for cmod regardless of signs since sign(x%(k*c)) = sign(x)
# (x%(k*c))%c = x%c for positive c
x = Variable("x", 0, 1575)
self.helper_test_variable(((x + (-1064)) % 512) % 4, -3, 3, "((x+-1064)%4)")
self.helper_test_variable(((x + (-1064)) % 512) % 128, -127, 127, "((x+-1064)%128)")
self.helper_test_variable(((x + (-1064)) % 512) % 4, 0, 3, "((x+-1064)%4)")
self.helper_test_variable(((x + (-1064)) % 512) % 128, 0, 127, "((x+-1064)%128)")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
@ -1062,12 +1063,13 @@ class TestSymInfer(unittest.TestCase):
assert sym_infer(a+b+c, var_vals) == 9
assert sym_infer(a*b, var_vals) == 6
assert sym_infer(a*b+c, var_vals) == 10
def test_sym_infer_cdiv_cmod(self):
def test_sym_infer_floordiv_floormod(self):
a = Variable("a", -1000, 1)
b = Variable("b", -1000, 1)
var_vals = {a.expr: 1, b.expr: -1000}
assert sym_infer(a%b, var_vals) == 1
assert sym_infer(a//b, var_vals) == 0
# floor: 1 % -1000 = -999, 1 // -1000 = -1
assert sym_infer(a%b, var_vals) == -999
assert sym_infer(a//b, var_vals) == -1
def test_sym_infer_with_bitcast(self):
a = Variable("a", 1, 10, dtypes.int)
expr = ((a.bitcast(dtypes.uint) << UOp.const(dtypes.uint, 1)).bitcast(dtypes.int) + 2)
@ -1286,7 +1288,8 @@ class TestGatedUopGivenValid(unittest.TestCase):
idx:UOp = (r0 < 3).where((r0 + uconst(-1)) // uconst(3), UOp.invalid())
idx = graph_rewrite(idx, pm_simplify_valid)
self.assertEqual(idx, (r0 < 3).where(uconst(0), UOp.invalid()))
# (r0-1)//3 = (r0+2)//3 - 1 (constant offset split)
self.assertEqual(idx, (r0 < 3).where((r0 + uconst(2)) // uconst(3) + uconst(-1), UOp.invalid()))
def test_invalid_gate_simplifies_vectorize(self):
r0 = Variable("r0", 0, 2)
@ -1295,8 +1298,8 @@ class TestGatedUopGivenValid(unittest.TestCase):
idx1 = r0 % uconst(3)
idx:UOp = (r0 < 3).where(UOp(Ops.STACK, dtypes.weakint.vec(2), (idx0, idx1)), UOp.invalid())
idx = graph_rewrite(idx, pm_simplify_valid)
# NOTE: independent simplification: (r0-1)//3 -> 0, r0%3 -> r0 when r0 in [0,2]
expected_vec = UOp(Ops.STACK, dtypes.weakint.vec(2), (uconst(0), r0))
# independent simplification: (r0-1)//3 -> (r0+2)//3 - 1, and r0%3 -> r0 when r0 in [0,2]
expected_vec = UOp(Ops.STACK, dtypes.weakint.vec(2), ((r0 + uconst(2)) // uconst(3) + uconst(-1), r0))
self.assertEqual(idx, (r0 < 3).where(expected_vec, UOp.invalid()))
class TestRangeSplitting(unittest.TestCase):
@ -1335,8 +1338,8 @@ class TestBounds(unittest.TestCase):
alu0 = gidx0 * -1
assert alu0.vmin == -2559 and alu0.vmax == 0
assert (alu0+2559).vmin == 0 and (alu0+2559).vmax == 2559
assert ((alu0+2559)//-4).vmin == -639 and ((alu0+2559)//-4).vmax == 0
assert (((alu0+2559)//-4)*(-1)).vmin == 0 and (((alu0+2559)//-4)*(-1)).vmax == 639
assert ((alu0+2559)//-4).vmin == -640 and ((alu0+2559)//-4).vmax == 0
assert (((alu0+2559)//-4)*(-1)).vmin == 0 and (((alu0+2559)//-4)*(-1)).vmax == 640
class TestFuzzFailure(unittest.TestCase):
def test_fuzz_failure1(self):

View file

@ -173,17 +173,15 @@ class TestVminVmaxDivMod(unittest.TestCase):
self.assertEqual(uop.vmax, 10)
def test_vmin_vmax_division_negative(self):
# vmin and vmax for division of a variable by a negative constant
# always positive
# floor division of a variable by a negative constant
x = UOp.variable('x', 10, 20)
uop = x // -2
self.assertEqual(uop.vmin, -10)
self.assertEqual(uop.vmax, -5)
uop = x // -3
self.assertEqual(uop.vmin, -6)
self.assertEqual(uop.vmax, -3)
self.assertEqual(uop.vmin, -7)
self.assertEqual(uop.vmax, -4)
# always negative
x = UOp.variable('x', -20, -10)
uop = x // -2
self.assertEqual(uop.vmin, 5)
@ -193,7 +191,6 @@ class TestVminVmaxDivMod(unittest.TestCase):
self.assertEqual(uop.vmax, 6)
def test_vmin_vmax_floordiv_floormod(self):
# FLOORDIV/FLOORMOD ranges differ from IDIV/MOD when the dividend can be negative
x = UOp.variable('x', -7, 7)
floordiv = x.alu(Ops.FLOORDIV, x.const_like(3))
self.assertEqual(floordiv.vmin, -3)
@ -212,32 +209,42 @@ class TestVminVmaxDivMod(unittest.TestCase):
self.assertEqual(uop.vmin, -5)
self.assertEqual(uop.vmax, 5)
uop = x // -3
self.assertEqual(uop.vmin, -3)
self.assertEqual(uop.vmin, -4)
self.assertEqual(uop.vmax, 3)
def test_vmin_vmax_floordiv_floormod_empty_range(self):
  # a RANGE with end=0 has an empty value range: its vmin (0) is greater than its vmax (-1)
  empty = UOp.range(0, 0)
  self.assertEqual(empty.vmin, 0)
  self.assertEqual(empty.vmax, -1)
  # with an empty numerator range, both `//` and `%` short-circuit their bounds to (0, 0)
  for derived in (empty // 4, empty % 4):
    self.assertEqual(derived.vmin, 0)
    self.assertEqual(derived.vmax, 0)
def test_vmin_vmax_div_symbolic(self):
x = UOp.variable('x', 1, 10)
y = UOp.variable('y', 3, 5)
self.assertEqual((x//y).vmin, 0)
self.assertEqual((x//y).vmax, 3)
self.assertEqual(((-x)//y).vmin, -3)
self.assertEqual(((-x)//y).vmax, 0)
self.assertEqual((x//(-y)).vmin, -3)
self.assertEqual((x//(-y)).vmax, 0)
self.assertEqual(((-x)//y).vmin, -4)
self.assertEqual(((-x)//y).vmax, -1)
self.assertEqual((x//(-y)).vmin, -4)
self.assertEqual((x//(-y)).vmax, -1)
self.assertEqual(((-x)//(-y)).vmin, 0)
self.assertEqual(((-x)//(-y)).vmax, 3)
self.assertEqual((100//y).vmin, 20)
self.assertEqual((100//y).vmax, 33)
self.assertEqual(((-100)//y).vmin, -33)
self.assertEqual(((-100)//y).vmin, -34)
self.assertEqual(((-100)//y).vmax, -20)
self.assertEqual((100//(-y)).vmin, -33)
self.assertEqual((100//(-y)).vmin, -34)
self.assertEqual((100//(-y)).vmax, -20)
self.assertEqual(((-100)//(-y)).vmin, 20)
self.assertEqual(((-100)//(-y)).vmax, 33)
def test_vmin_vmax_mod_positive(self):
# vmin and vmax for modulo of a variable by a positive constant
# floor mod with positive divisor: result in [0, c-1] regardless of dividend sign
positive = UOp.variable('positive', 10, 20)
uop = positive % 3
self.assertEqual(uop.vmin, 0)
@ -245,20 +252,20 @@ class TestVminVmaxDivMod(unittest.TestCase):
negative = UOp.variable('negative', -20, -10)
uop = negative % 3
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmax, 0)
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 2)
mixed = UOp.variable('mixed', -20, 20)
uop = mixed % 3
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 2)
def test_vmin_vmax_mod_negative(self):
# vmin and vmax for modulo of a variable by a negative constant
# floor mod with negative divisor: result in [c+1, 0] regardless of dividend sign
positive = UOp.variable('positive', 10, 20)
uop = positive % -3
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 2)
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmax, 0)
negative = UOp.variable('negative', -20, -10)
uop = negative % -3
@ -268,7 +275,7 @@ class TestVminVmaxDivMod(unittest.TestCase):
mixed = UOp.variable('mixed', -20, 20)
uop = mixed % -3
self.assertEqual(uop.vmin, -2)
self.assertEqual(uop.vmax, 2)
self.assertEqual(uop.vmax, 0)
class TestVminVmaxVConst(unittest.TestCase):
def test_vmin_vmax_vconst_single_element(self):

View file

@ -177,6 +177,18 @@ class TestFastIdiv(unittest.TestCase):
self.assertIn(Ops.SHR, ops, f"For dtype={dt} divison by power of two did not simplify to shift")
self.assertNotIn(Ops.IDIV, ops, f"For dtype={dt} divison by power of two did not simplify to shift")
def test_floordiv_power_of_two_uint(self):
  # an unsigned FLOORDIV by a power-of-two constant must come out of lowering as a shift:
  # the final uop list has to contain SHR and neither IDIV nor FLOORDIV
  for dt in (dtypes.uint32, dtypes.uint64):
    buf = UOp(Ops.PARAM, dt.ptr(), (), 0)
    two = UOp.const(dt, 2)
    quotient = UOp(Ops.FLOORDIV, dt, (buf.index(two), two))
    emitted = [u.op for u in to_uops_list([quotient], ren=Device[Device.DEFAULT].renderer)]
    not_shift = f"For dtype={dt} FLOORDIV by power of two did not simplify to shift"
    self.assertIn(Ops.SHR, emitted, not_shift)
    self.assertNotIn(Ops.IDIV, emitted, not_shift)
    self.assertNotIn(Ops.FLOORDIV, emitted, f"For dtype={dt} FLOORDIV survived past late rewrite")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
def test_fast_idiv_and_mod(self):
g = UOp(Ops.PARAM, dtypes.uint32.ptr(), (), 0)

View file

@ -17,7 +17,8 @@ pm_flatten_range = PatternMatcher([
(UPat((Ops.REDUCE, Ops.END), name="r"), flatten_range),
])
def count_divmod(x:UOp) -> int: return sum(u.op in {Ops.IDIV, Ops.MOD} for u in x.backward_slice)
# before the late rewrite, index/range arithmetic is expressed with FLOORDIV/FLOORMOD,
# so those are the ops counted here
def count_divmod(x:UOp) -> int:
  # number of FLOORDIV/FLOORMOD uops found in the backward slice of x
  divmod_ops = {Ops.FLOORDIV, Ops.FLOORMOD}
  return sum(1 for u in x.backward_slice if u.op in divmod_ops)
def simplify_merge_adjacent(u:UOp) -> UOp|None:
reduce_ranges = [x.ranges for x in u.backward_slice_with_self if x.op is Ops.REDUCE]
# on END we only want to merge adjacent ranges, on REDUCE we want to try all combinations

View file

@ -241,7 +241,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), Contex
RING, ALL2ALL, ALLREDUCE_CAST = ContextVar("RING", 1), ContextVar("ALL2ALL", 0), ContextVar("ALLREDUCE_CAST", 1)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
FUSE_OPTIM = ContextVar("FUSE_OPTIM", 0)
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
MAX_KERNEL_BUFFERS = ContextVar("MAX_KERNEL_BUFFERS", 0)
EMULATED_DTYPES = ContextVar("EMULATED_DTYPES", "")

View file

@ -181,7 +181,7 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
return self._binop(Ops.IDIV, x, reverse)
def mod(self, x: Self | ConstType, reverse: bool = False) -> Self:
return self._binop(Ops.MOD, x, reverse)
return self._binop(Ops.FLOORMOD, x, reverse)
def div(self, x: Self | ConstType, reverse: bool = False) -> Self:
lhs, rhs = self._broadcasted(x, reverse)
@ -206,7 +206,7 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
return self.div(x)
def __floordiv__(self, x: Self | ConstType) -> Self:
return self.idiv(x) # TODO: idiv is trunc div, not floordiv
return self._binop(Ops.FLOORDIV, x, False)
def __mod__(self, x: Self | ConstType) -> Self:
return self.mod(x)
@ -233,7 +233,7 @@ class ElementwiseMixin(DTypeMixin, CreationMixin):
return self.div(x, True)
def __rfloordiv__(self, x: Self | ConstType) -> Self:
return self.idiv(x, True)
return self._binop(Ops.FLOORDIV, x, True)
def __rand__(self, x: Self | ConstType) -> Self:
return self.bitwise_and(x, True)

View file

@ -290,8 +290,10 @@ def fast_idiv(target: Target, x: UOp, d: int, dont_cast=False) -> UOp|None:
if m*vmin >= x.dtype.min and m*vmax <= x.dtype.max:
return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0)
# before we try casting to a larger dtype (slow), we see if there are powers of two in d we can shift to make x smaller
# use explicit Ops.IDIV (trunc) since the recursion assumes trunc semantics throughout
if (largest_factor_of_two_in_d := (d & -d)) > 1:
if (ret:=fast_idiv(target, x//largest_factor_of_two_in_d, d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret
if (ret:=fast_idiv(target, x.alu(Ops.IDIV, x.const_like(largest_factor_of_two_in_d)),
d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret
if dont_cast: return None
# promo_lattice needs to return an unsigned type if the type is unsigned
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, target):
@ -459,22 +461,30 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> Pa
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
# TODO: drop the x.vmin>=0 guard once UOp `%` lowers to FLOORMOD instead of MOD
# rewrite FLOORMOD to AND on power-of-2 const: x % (2**y) -> x & (2**y-1) (correct floor mod for any sign in two's complement)
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"),
lambda x,c: x & (c.arg-1) if c.arg in powers_of_two and x.vmin >= 0 else None)]
lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
if Ops.SHR in ops:
# no reason to check x<0 for uints
pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
# uint floor==trunc, so safe for both ops
pat += [(UPat((Ops.IDIV, Ops.FLOORDIV), src=(UPat.var("x", dtypes.uints), UPat.cvar("c"))),
lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
# signed IDIV (trunc) by 2**v -> (x + (x<0 ? c-1 : 0)) >> v
# the (x<0 ? c-1 : 0) fix-up is only correct for trunc division, so this matches raw Ops.IDIV and never signed FLOORDIV
pat += [(UPat(Ops.IDIV, src=(UPat.var("x", dtypes.ints), UPat.cvar("c"))),
lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(c-1, 0)) >> v
if (v:=powers_of_two.get(c.arg, 0)) else None)]
if not disable_fast_idiv:
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
# fast_idiv handles non-pow2: only fire on non-negative inputs (signed magic-mul is unreliable for x<0)
pat += [(UPat(Ops.IDIV, src=(UPat.var("x", dtypes.ints), UPat.cvar("d", vec=False))),
lambda ctx, x, d: fast_idiv(ctx, x, d.arg) if x.vmin >= 0 or x.dtype in dtypes.uints else None)]
# rewrite raw MOD -> x - d*IDIV(x,d) so fast_idiv can pick up the IDIV. only on non-negative inputs;
# avoids disturbing floormod_to_mod's general-path output (which uses a trunc Ops.MOD as an implementation detail)
pat += [(UPat(Ops.MOD, src=(UPat.var("x", dtypes.ints), UPat.var("d"))),
lambda x, d: x - d * x.alu(Ops.IDIV, d) if x.vmin >= 0 or x.dtype in dtypes.uints else None)]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))]

View file

@ -1,19 +1,19 @@
import functools, itertools, math
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
from tinygrad.helpers import floordiv, floormod, unwrap
# NOTE: this cache is only on index UOps
@functools.cache
def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
def fold_divmod_general(d: UOp) -> UOp|None:
x, y = d.src
# cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval
x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax
assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int)
if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
if y_min*y_max > 0 and (qv:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max):
return x - qv*y if d.op is Ops.MOD else d.const_like(qv)
if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.FLOORDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
if y_min*y_max > 0 and (qv:=floordiv(x_min,y_min)) == floordiv(x_min,y_max) == floordiv(x_max,y_min) == floordiv(x_max,y_max):
return x - qv*y if d.op is Ops.FLOORMOD else d.const_like(qv)
# split uops for the rest of the processing
x_peeled, const = x.pop_const()
@ -22,19 +22,20 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
# ** Constant Denominator Rules **
# these rules strictly require y to be a scalar constant > 0
if y.op is Ops.CONST and (c := y.arg) > 0:
# nested_div_mod: (x%(k*c))//c -> (x//c)%k, and (x%(k*c))%c -> x%c
if x.op is Ops.MOD and (k := x.src[1].divides(c)) is not None:
return x.src[0] // y % k if d.op is Ops.IDIV else x.src[0] % y
# nested_div_mod: (x%(k*c))//c -> (x//c)%k (requires k>0), and (x%(k*c))%c -> x%c
if x.op is Ops.FLOORMOD and (k := x.src[1].divides(c)) is not None:
if d.op is Ops.FLOORMOD: return x.src[0] % y
if k > 0: return x.src[0] // y % k
# remove_nested_mod in sum: (a%4 + b)%2 -> (a+b)%2, requires non-negative sums
if d.op is Ops.MOD and x.vmin >= 0:
# remove_nested_mod in sum: (a%4 + b)%2 -> (a+b)%2
if d.op is Ops.FLOORMOD:
new_xs, changed = [], False
for u in uops_no_const:
if u.op is Ops.MOD and u.src[1].divides(c) is not None:
if u.op is Ops.FLOORMOD and u.src[1].divides(c) is not None:
u = u.src[0]
changed = True
new_xs.append(u)
if changed and (new_x:=(UOp.usum(*new_xs) + const)).vmin >= 0: return new_x % y
if changed: return (UOp.usum(*new_xs) + const) % y
# Shared decomposition for folding rules
decomp = [(u.divides(f:=u.const_factor()),f) for u in uops_no_const]
@ -42,32 +43,31 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
# fold_binary_numerator: fold if expression has one non-constant term that takes on two values
if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
y1 = (cmod if d.op is Ops.MOD else cdiv)(factors[0]*v.vmin+const, c)
y2 = (cmod if d.op is Ops.MOD else cdiv)(factors[0]*v.vmax+const, c)
y1 = (floormod if d.op is Ops.FLOORMOD else floordiv)(factors[0]*v.vmin+const, c)
y2 = (floormod if d.op is Ops.FLOORMOD else floordiv)(factors[0]*v.vmax+const, c)
return (y2-y1)*(v-v.vmin) + y1
# fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c
if not (x.vmin<0 and correct_divmod_folding):
# when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period
rem_choices = [(r, r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors]
for rems in itertools.product(*rem_choices):
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c:
if d.op is Ops.MOD: return rem - rem.vmin//c*c
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + const//c + rem.vmin//c
# when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period
rem_choices = [(r, r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors]
for rems in itertools.product(*rem_choices):
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c:
if d.op is Ops.FLOORMOD: return rem - rem.vmin//c*c
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + const//c + rem.vmin//c
# gcd_with_remainder: factor out common gcd from numerator
if x.vmin >= 0 and (g:=math.gcd(*factors, c)) > 1:
new_x = unwrap(x_peeled.divides(g)).simplify() + (const//g)%(c//g)
if new_x.vmin >= 0:
if d.op is Ops.MOD: return new_x % (c//g) * g + const%g
if d.op is Ops.FLOORMOD: return new_x % (c//g) * g + const%g
return new_x // (c//g) + const//c
# nest_by_factor: x//c -> (x//f)//(c//f), x%c -> (x//f%(c//f))*f + b where b=x%f
if x.vmin >= 0:
results = []
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
if d.op is Ops.IDIV:
if (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0:
if d.op is Ops.FLOORDIV:
results.append((len(newxs.backward_slice), newxs // (c // div)))
else:
b_parts = [f%div*t for f, t in zip(factors, terms) if f%div]
@ -86,7 +86,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
gcd = UOp.gcd(*all_uops, y).simplify()
if not (gcd.op is Ops.CONST and gcd.arg==1):
ret = unwrap(x.divide_exact(gcd)).alu(d.op, unwrap(y.divide_exact(gcd)))
return ret*gcd if d.op is Ops.MOD else ret
return ret*gcd if d.op is Ops.FLOORMOD else ret
# factor_remainder: (d*x+y)//d -> x+y//d
if y.vmin<0 or x.vmin<0: return None
@ -95,29 +95,22 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
if (q:=u.divide_exact(y)) is not None: quo.append(q)
elif y.op is Ops.CONST and (c:=u.const_factor())%y.arg!=c:
rem.append(u.divides(c)*(c%y.arg))
quo.append(u.divides(c)*(c//y.arg) if d.op is Ops.IDIV else u.const_like(0))
quo.append(u.divides(c)*(c//y.arg) if d.op is Ops.FLOORDIV else u.const_like(0))
else: rem.append(u)
if not quo: return None
new_x = sum(rem)+x.const_like(0)
if new_x.vmin<0: return None
return new_x%y if d.op is Ops.MOD else new_x//y+sum(quo)
return new_x%y if d.op is Ops.FLOORMOD else new_x//y+sum(quo)
div_and_mod_symbolic = PatternMatcher([
# ** 1. Fast Inline Rules **
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
if c.vmin>0 and d.vmin>0 and x.vmin>=0 and a.vmin>=0 else None), # (x//c+a)//d -> (x+a*c)//(c*d)
(UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
(UPat.var("x", dtypes.weakint) // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <= 0 else None),
((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
lambda x,c,n,d: ((x+c.arg%d.arg)//d + c.arg//d.arg) if c.arg%d.arg!=c.arg and x.vmin>=0 and n.vmin>=0 and d.arg>0 else None),
((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
# (x//c+a)//d -> (x+a*c)//(c*d) for c>0, d>0
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d) if c.vmin>0 and d.vmin>0 else None),
# (x+c)//d -> (x+c%d)//d + c//d for d>0 (split out the multiple of d in the constant)
((UPat.var("x", dtypes.weakint)+UPat.cvar("c", vec=False))//UPat.cvar("d", vec=False),
lambda x,c,d: (x+c.arg%d.arg)//d + c.arg//d.arg if c.arg%d.arg!=c.arg and d.arg>0 else None),
# ** 2. Slow Rules **
(UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
# NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops
(UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
(UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
])
(UPat((Ops.FLOORDIV, Ops.FLOORMOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d)),
])

View file

@ -869,9 +869,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return min(vals:=(cdiv(s0_vmin, s1_vmin), cdiv(s0_vmin, s1_vmax), cdiv(s0_vmax, s1_vmin), cdiv(s0_vmax, s1_vmax))), max(vals)
if self.op is Ops.FLOORDIV:
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
if s0_vmin > s0_vmax: return 0, 0 # numerator range is empty (e.g. RANGE with end=0)
if s1_vmin*s1_vmax>0: return min(vals:=(s0_vmin//s1_vmin, s0_vmin//s1_vmax, s0_vmax//s1_vmin, s0_vmax//s1_vmax)), max(vals)
if self.op is Ops.FLOORMOD:
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
if s0_vmin > s0_vmax: return 0, 0 # numerator range is empty (e.g. RANGE with end=0)
if (c:=s1_vmin) == s1_vmax > 0: return (s0_vmin%c, s0_vmax%c) if s0_vmin//c == s0_vmax//c else (0, c-1)
if (c:=s1_vmin) == s1_vmax < 0: return (s0_vmin%c, s0_vmax%c) if s0_vmin//c == s0_vmax//c else (c+1, 0)
if s1_vmin > 0: return (0, s1_vmax-1)
@ -906,7 +908,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# TODO: sanitize varnames, or don't use naked eval while staying fast
ret = _render_with_splits(list(sself.toposort()), renderer_infer, {sself})
lines = [f" {k}={v}" for k,v in ret.items() if k != "ast"] + [f" return {ret['ast']}"]
ns: dict[str, Any] = {"max": max, "cdiv": cdiv, "cmod": cmod, "bitcast": bitcast, "dtypes": dtypes}
ns: dict[str, Any] = {"max": max, "cdiv": cdiv, "cmod": cmod, "floordiv": floordiv, "floormod": floormod, "bitcast": bitcast, "dtypes": dtypes}
exec(f"def _f({','.join(varnames)}):\n"+'\n'.join(lines), ns) # pylint: disable=exec-used
return ns["_f"], varnames

View file

@ -23,10 +23,10 @@ def print_uops(uops:list[UOp]):
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
# for debug
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.FLOORDIV: "//", Ops.FLOORMOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
# comparison operators are not in here because they are chained in python, not left-associative
precedence = {Ops.MUL:1, Ops.IDIV:1, Ops.MOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6}
precedence = {Ops.MUL:1, Ops.FLOORDIV:1, Ops.FLOORMOD:1, Ops.ADD:2, Ops.SUB:2, Ops.SHL:3, Ops.SHR:3, Ops.AND:4, Ops.XOR:5, Ops.OR:6}
def strip_binary_parens(x:UOp, left:str, right:str, code_for_op) -> str:
if x.op not in precedence: return code_for_op(left, right)
return code_for_op(strip_parens(left) if precedence.get(x.src[0].op,99)<=precedence[x.op] else left, strip_parens(right) if
@ -46,6 +46,8 @@ renderer = PatternMatcher([
(UPat(Ops.MAX, name="x"), lambda ctx,x: f"max({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.MULACC, name="x"), lambda ctx,x: f"({ctx[x.src[0]]}*{ctx[x.src[1]]}+{ctx[x.src[2]]})"),
(UPat(Ops.WHERE, name="x"), lambda ctx,x: f"({ctx[x.src[1]]} if {ctx[x.src[0]]} else {ctx[x.src[2]]})"),
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"cdiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.MOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(set(syms.keys()), name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
(UPat(Ops.STACK, name="x"),
@ -56,6 +58,8 @@ renderer = PatternMatcher([
renderer_infer = PatternMatcher([
(UPat(Ops.MOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"cdiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.FLOORMOD, name="x"), lambda ctx,x: f"floormod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.FLOORDIV, name="x"), lambda ctx,x: f"floordiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast({ctx[x.src[0]]}, {x.src[0].dtype!r}, {x.dtype!r})"),
]) + renderer
@ -99,13 +103,16 @@ pm_pyrender_extra = PatternMatcher([
# TODO: movement ops simplify stuff, this can break SPEC=2
#(UPat(GroupOp.Movement, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({render_marg(ctx,x)})"),
# NOTE: CMPNE doesn't work cause there's no __rne__
# explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render IDIV/MOD via their named methods
(UPat(Ops.IDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.idiv({ctx[x.src[1]]})"),
(UPat(Ops.MOD, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.MOD, {ctx[x.src[1]]})"),
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering
(UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")), name="x"),
(UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE, Ops.IDIV, Ops.MOD}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")), name="x"),
lambda ctx,x,y,z: strip_binary_parens(x, str(y.arg), ctx[z], lambda a,b: f"({a}{syms[x.op]}{b})")),
# NOTE: sub doesn't work cause it's written as add/mul
(UPat(set(syms.keys())-{Ops.SUB}, src=(UPat(name="y"), UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="z")), name="x"), lambda ctx,x,y,z:
strip_binary_parens(x, ctx[y], str(z.arg), lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat(set(syms.keys())-{Ops.SUB}, name="x"), lambda ctx,x:
(UPat(set(syms.keys())-{Ops.SUB, Ops.IDIV, Ops.MOD}, src=(UPat(name="y"), UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="z")), name="x"),
lambda ctx,x,y,z: strip_binary_parens(x, ctx[y], str(z.arg), lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat(set(syms.keys())-{Ops.SUB, Ops.IDIV, Ops.MOD}, name="x"), lambda ctx,x:
strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat(sugar, src=(), name="x"), lambda x: f"UOp.{x.op.name.lower()}("+', '.join(([f'arg={repr(x.arg)}'] if x.arg is not None else []))+")"),
(UPat(sugar, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}("+', '.join([ctx[y] for y in x.src[1:]] + \

View file

@ -28,8 +28,8 @@ invalid_gate = UPat.var("cond").where(UPat.var("x"), invalid_pat)
def fold_add_divmod_recombine(x:UOp) -> UOp|None:
terms = list(x.split_uop(Ops.ADD))
for i,u in enumerate(terms):
if u.op is Ops.MOD and u.src[1].op is Ops.CONST: base, div, mul = u.src[0], u.src[1].arg, 1
elif u.op is Ops.MUL and u.src[1].op is Ops.CONST and (m:=u.src[0]).op is Ops.MOD and m.src[1].op is Ops.CONST:
if u.op is Ops.FLOORMOD and u.src[1].op is Ops.CONST: base, div, mul = u.src[0], u.src[1].arg, 1
elif u.op is Ops.MUL and u.src[1].op is Ops.CONST and (m:=u.src[0]).op is Ops.FLOORMOD and m.src[1].op is Ops.CONST:
base, div, mul = m.src[0], m.src[1].arg, u.src[1].arg
else: continue
for j,v in enumerate(terms):
@ -37,13 +37,13 @@ def fold_add_divmod_recombine(x:UOp) -> UOp|None:
if v.op is not Ops.MUL or v.src[1].op is not Ops.CONST or v.src[1].arg != div*mul: continue
q, exact = v.src[0], False
# (base%div)*mul + (base//div)*(div*mul) -> base*mul
if q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[1].arg == div: exact = q.src[0] is base
if q.op is Ops.FLOORDIV and q.src[1].op is Ops.CONST and q.src[1].arg == div: exact = q.src[0] is base
# ((base//d)%div)*mul + (base//(d*div))*(div*mul) -> (base//d)*mul
if not exact and base.op is Ops.IDIV and base.src[1].op is Ops.CONST:
exact = q.op is Ops.IDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div
if not exact and base.op is Ops.FLOORDIV and base.src[1].op is Ops.CONST:
exact = q.op is Ops.FLOORDIV and q.src[1].op is Ops.CONST and q.src[0] is base.src[0] and q.src[1].arg == base.src[1].arg*div
if exact: return (base*mul).usum(*[t for k,t in enumerate(terms) if k not in (i,j)])
# ((base//div)%d)*div + base%div -> base%(div*d)
if mul == 1 and div > 0 and q.op is Ops.MOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.IDIV:
if mul == 1 and div > 0 and q.op is Ops.FLOORMOD and q.src[1].op is Ops.CONST and (d:=q.src[1].arg) > 0 and q.src[0].op is Ops.FLOORDIV:
if q.src[0].src[0] is base and q.src[0].src[1].op is Ops.CONST and q.src[0].src[1].arg == div:
return (base % (div*d)).usum(*[t for k,t in enumerate(terms) if k not in (i,j)])
return None
@ -244,7 +244,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("y")+UPat.var("c").where(UPat.var("t"), UPat.var("f"))) + UPat.var("c").where(UPat.var("tt"), UPat.var("ff")), \
lambda y,c,t,tt,f,ff: y+c.where(t+tt, f+ff) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU/variable min==max -> CONST
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.IDIV, Ops.MOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
(UPat({Ops.CMPLT, Ops.CMPNE, Ops.FLOORDIV, Ops.FLOORMOD, Ops.DEFINE_VAR, Ops.BIND, Ops.SPECIAL}, name="x"),
lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(Ops.RANGE, src=(UPat(Ops.CONST,)), name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
@ -263,9 +263,9 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
# c0*x<c1 for negative int c0 and non-positive c1
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.weakint))<UPat.cvar("c1", vec=False),
lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# x//d<c
# x//d<c -> x<c*d for d>0
((UPat.var("x", dtype=dtypes.weakint)//UPat.cvar("d", vec=False))<UPat.cvar("c", vec=False),
lambda x,d,c: (x<(c.arg*d.arg) if c.arg > 0 else x<(c.arg*d.arg-(d.arg-1))) if d.arg > 0 else None),
lambda x,d,c: x<(c.arg*d.arg) if d.arg > 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
((UPat.var("x") + UPat.cvar("c1")) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
@ -408,7 +408,7 @@ pm_move_where_on_load = PatternMatcher([
def gated_given_valid(cond:UOp, x:UOp, i:UOp) -> UOp|None:
if x.dtype.scalar() is not dtypes.weakint: return None
# Skip if x contains DIV/MOD AND IMAGE mode is enabled -> image index e.g. openpilot
if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD): return None
if IMAGE.value > 0 and x.op_in_backward_slice_with_self(Ops.IDIV, Ops.MOD, Ops.FLOORDIV, Ops.FLOORMOD): return None
return cond.where(uop_given_valid(cond, x, try_simplex=False), i)
# TODO: this is O(number of WHERE * number of node)