relax fold_divmod_general (#16058)

This commit is contained in:
chenyu 2026-05-05 21:37:56 -04:00 committed by GitHub
commit aaabe42373
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 15 deletions

View file

@ -550,7 +550,18 @@ class TestSymbolic(unittest.TestCase):
def test_nest_div_negative_factor(self):
ridx0=Variable("ridx0", 0, 9)
ridx1=Variable("ridx1", 0, 6)
self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "((ridx1+ridx0*-7+28)//35+1)")
self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "((ridx0*-1+4)//5+1)")
def test_floordiv_factor_nest_negative_numerator(self):
# x//c = (x//f)//(c//f) for f|c, any sign of x
a = Variable("a", -10, 10)
b = Variable("b", 0, 3)
self.helper_test_variable((a*4 + b)//12, -4, 3, "(a//3)")
def test_floordiv_gcd_with_remainder_negative_numerator(self):
# factor gcd from numerator, even when x crosses zero, as long as the shifted numerator stays nonneg
a = Variable("a", -1, 5)
self.helper_test_variable((a*2 + 7)//8, 0, 2, "((a+3)//4)")
def test_div_into_mod(self):
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")

View file

@ -56,26 +56,26 @@ def fold_divmod_general(d: UOp) -> UOp|None:
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + const//c + rem.vmin//c
# gcd_with_remainder: factor out common gcd from numerator
if x.vmin >= 0 and (g:=math.gcd(*factors, c)) > 1:
if (g:=math.gcd(*factors, c)) > 1:
new_x = unwrap(x_peeled.divides(g)).simplify() + (const//g)%(c//g)
if new_x.vmin >= 0:
if d.op is Ops.FLOORMOD: return new_x % (c//g) * g + const%g
return new_x // (c//g) + const//c
# nest_by_factor: x//c -> (x//f)//(c//f), x%c -> (x//f%(c//f))*f + b where b=x%f
if x.vmin >= 0:
results = []
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
if (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0:
if d.op is Ops.FLOORDIV:
results.append((len(newxs.backward_slice), newxs // (c // div)))
else:
b_parts = [f%div*t for f, t in zip(factors, terms) if f%div]
if const % div: b_parts.append(x.const_like(const % div))
b = UOp.usum(*b_parts) if b_parts else x.const_like(0)
if 0 <= b.vmin and b.vmax < div:
results.append((len((r:=(newxs % x.ufix(c//div))*div + b).backward_slice), r))
if results: return min(results, key=lambda r: r[0])[1]
# FLOORDIV identity holds for any sign of x; FLOORMOD reconstruction needs x.vmin>=0
results = []
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
if (newxs := fold_divmod_general(x//div)) is not None:
if d.op is Ops.FLOORDIV:
results.append((len(newxs.backward_slice), newxs // (c // div)))
elif x.vmin >= 0 and newxs.vmin >= 0:
b_parts = [f%div*t for f, t in zip(factors, terms) if f%div]
if const % div: b_parts.append(x.const_like(const % div))
b = UOp.usum(*b_parts) if b_parts else x.const_like(0)
if 0 <= b.vmin and b.vmax < div:
results.append((len((r:=(newxs % x.ufix(c//div))*div + b).backward_slice), r))
if results: return min(results, key=lambda r: r[0])[1]
# ** Variable Denominator / Fallback Rules **
# These rules apply to variables OR constants that failed the checks above.