mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
relax fold_divmod_general (#16058)
This commit is contained in:
parent
1de14cf33a
commit
aaabe42373
2 changed files with 26 additions and 15 deletions
|
|
@@ -550,7 +550,18 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_nest_div_negative_factor(self):
|
||||
ridx0=Variable("ridx0", 0, 9)
|
||||
ridx1=Variable("ridx1", 0, 6)
|
||||
self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "((ridx1+ridx0*-7+28)//35+1)")
|
||||
self.helper_test_variable(((((ridx0*-7)+ridx1)+63)//35), 0, 1, "((ridx0*-1+4)//5+1)")
|
||||
|
||||
def test_floordiv_factor_nest_negative_numerator(self):
|
||||
# x//c = (x//f)//(c//f) for f|c, any sign of x
|
||||
a = Variable("a", -10, 10)
|
||||
b = Variable("b", 0, 3)
|
||||
self.helper_test_variable((a*4 + b)//12, -4, 3, "(a//3)")
|
||||
|
||||
def test_floordiv_gcd_with_remainder_negative_numerator(self):
|
||||
# factor gcd from numerator, even when x crosses zero, as long as the shifted numerator stays nonneg
|
||||
a = Variable("a", -1, 5)
|
||||
self.helper_test_variable((a*2 + 7)//8, 0, 2, "((a+3)//4)")
|
||||
|
||||
def test_div_into_mod(self):
|
||||
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
|
||||
|
|
|
|||
|
|
@@ -56,26 +56,26 @@ def fold_divmod_general(d: UOp) -> UOp|None:
|
|||
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + const//c + rem.vmin//c
|
||||
|
||||
# gcd_with_remainder: factor out common gcd from numerator
|
||||
if x.vmin >= 0 and (g:=math.gcd(*factors, c)) > 1:
|
||||
if (g:=math.gcd(*factors, c)) > 1:
|
||||
new_x = unwrap(x_peeled.divides(g)).simplify() + (const//g)%(c//g)
|
||||
if new_x.vmin >= 0:
|
||||
if d.op is Ops.FLOORMOD: return new_x % (c//g) * g + const%g
|
||||
return new_x // (c//g) + const//c
|
||||
|
||||
# nest_by_factor: x//c -> (x//f)//(c//f), x%c -> (x//f%(c//f))*f + b where b=x%f
|
||||
if x.vmin >= 0:
|
||||
results = []
|
||||
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
|
||||
if (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0:
|
||||
if d.op is Ops.FLOORDIV:
|
||||
results.append((len(newxs.backward_slice), newxs // (c // div)))
|
||||
else:
|
||||
b_parts = [f%div*t for f, t in zip(factors, terms) if f%div]
|
||||
if const % div: b_parts.append(x.const_like(const % div))
|
||||
b = UOp.usum(*b_parts) if b_parts else x.const_like(0)
|
||||
if 0 <= b.vmin and b.vmax < div:
|
||||
results.append((len((r:=(newxs % x.ufix(c//div))*div + b).backward_slice), r))
|
||||
if results: return min(results, key=lambda r: r[0])[1]
|
||||
# FLOORDIV identity holds for any sign of x; FLOORMOD reconstruction needs x.vmin>=0
|
||||
results = []
|
||||
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
|
||||
if (newxs := fold_divmod_general(x//div)) is not None:
|
||||
if d.op is Ops.FLOORDIV:
|
||||
results.append((len(newxs.backward_slice), newxs // (c // div)))
|
||||
elif x.vmin >= 0 and newxs.vmin >= 0:
|
||||
b_parts = [f%div*t for f, t in zip(factors, terms) if f%div]
|
||||
if const % div: b_parts.append(x.const_like(const % div))
|
||||
b = UOp.usum(*b_parts) if b_parts else x.const_like(0)
|
||||
if 0 <= b.vmin and b.vmax < div:
|
||||
results.append((len((r:=(newxs % x.ufix(c//div))*div + b).backward_slice), r))
|
||||
if results: return min(results, key=lambda r: r[0])[1]
|
||||
|
||||
# ** Variable Denominator / Fallback Rules **
|
||||
# These rules apply to variables OR constants that failed the checks above.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue