Merge branch 'master' into no_merge_views

This commit is contained in:
George Hotz 2025-08-14 08:07:52 -07:00 committed by GitHub
commit 4fd4e13fcf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 115 additions and 57 deletions

View file

@ -143,13 +143,12 @@ class TestIndexingConstFolding(unittest.TestCase):
_check_ast_count(1, t[:,:,Tensor(1)+2,:])
_check_ast_count(1, t[:,:,Tensor(1),Tensor(0)])
@unittest.expectedFailure
def test_const_tensor_index(self):
# TODO: implement const tensor folded indexing
# TODO: these can be 0, implement const tensor folded indexing
t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
_check_ast_count(0, t[:,:,Tensor.ones(2,1),:])
_check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:])
_check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
_check_ast_count(1, t[:,:,Tensor.ones(2,1,dtype=dtypes.int),:])
_check_ast_count(1, t[:,:,Tensor.ones(1,2,dtype=dtypes.int)+2,:])
_check_ast_count(1, t[:,:,Tensor.ones(1,1,dtype=dtypes.int),Tensor.zeros(2,1,2,dtype=dtypes.int)])
class TestMovedConstFolding(unittest.TestCase):
def test_add_shrunk_zero(self):

View file

@ -30,16 +30,16 @@ class TestSymbolicPickle(unittest.TestCase):
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s, test_z3:bool=True):
rendered, nmin, nmax = render(v)
if isinstance(s, tuple): self.assertIn(rendered, s)
else: self.assertEqual(rendered, s)
self.assertEqual(nmin, n)
self.assertEqual(nmax, m)
if test_z3:
solver = z3.Solver()
z3_sink = graph_rewrite(v.sink(v.simplify()), z3_renderer, ctx=(solver, {}))
expr, epxr_simplified = z3_sink.src[0].arg, z3_sink.src[1].arg
self.assertEqual(solver.check(expr != epxr_simplified), z3.unsat, "simplified expression not equal to original")
rendered, nmin, nmax = render(v)
if isinstance(s, tuple): self.assertIn(rendered, s)
else: self.assertEqual(rendered, s)
self.assertEqual(nmin, n)
self.assertEqual(nmax, m)
def test_cmp_simple(self):
self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
@ -266,6 +266,16 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, "(((a*5)%12)%5)")
self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
def test_mod_mod_wrong_sign(self):
  # the outer %5 wraps a sum with a nested mod and a -2 offset; the expected render is the input unchanged, with bounds [-4, 4]
  a = Variable("v1", 0, 128)
  b = Variable("v3", 0, 7)
  expr = ((((a%2)*2)+((b+-1)%5))+-2)%5
  self.helper_test_variable(expr, -4, 4, "(((((v1%2)*2)+((v3+-1)%5))+-2)%5)")
def test_mod_mod_wrong_sign2(self):
  # the doubled outer %7 is expected to render as a single %7 over the reordered sum, with bounds [-6, 6]
  a = Variable("v2", 0, 8)
  b = Variable("v3", 0, 4)
  expr = ((((b+3)%7)+(a+-2))%7)%7
  self.helper_test_variable(expr, -6, 6, "(((v2+((v3+3)%7))+-2)%7)")
def test_mul_mul(self):
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
@ -375,6 +385,17 @@ class TestSymbolic(unittest.TestCase):
def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
def test_div_drop_small_terms(self):
  # from openpilot: this division shouldn't simplify, the expected render keeps every term of the numerator
  gidx0, gidx1 = UOp.variable("gidx0", 0, 10), UOp.variable("gidx1", 0, 10)
  lidx0, lidx1 = UOp.variable("lidx0", 0, 1), UOp.variable("lidx1", 0, 1)
  ridx1005, ridx1006 = UOp.variable("ridx1005", 0, 2), UOp.variable("ridx1006", 0, 2)
  numerator = lidx1+((gidx1*18)+(ridx1005*18)+(lidx0*162))+(gidx0*2)+(ridx1006*2)+-40
  expected = "(((((lidx1+(((gidx1*18)+(ridx1005*18))+(lidx0*162)))+(gidx0*2))+(ridx1006*2))+-40)//18)"
  self.helper_test_variable(numerator//18, -2, 20, expected)
def test_add_div(self):
# careful about the lower bounds and upper bounds
self.helper_test_variable((Variable("a", 0, 5)-2)//4, 0, 0, "0")
@ -421,6 +442,11 @@ class TestSymbolic(unittest.TestCase):
def test_div_numerator_negative(self):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")
def test_nest_div_negative_factor(self):
  # numerator carries a negative factor (-7); the expected result is expressed through ridx0//5 only, with bounds [0, 1]
  outer = UOp.variable("ridx0", 0, 9)
  inner = UOp.variable("ridx1", 0, 6)
  numerator = ((outer*-7)+inner)+63
  self.helper_test_variable(numerator//35, 0, 1, "(((ridx0//5)*-1)+1)")
def test_div_into_mod(self):
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")

View file

@ -487,6 +487,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
def pop_const(self) -> tuple[UOp, int]:
  # split a constant second operand off an ADD: returns (first operand, constant value), or (self, 0) when self is not `something + CONST`
  if self.op is Ops.ADD and self.src[1].op is Ops.CONST: return self.src[0], self.src[1].arg
  return self, 0
@property
def vmin(self) -> ConstType: return self._min_max[0]
@property

View file

@ -1,5 +1,5 @@
# all of symbolic lives here now
from typing import Any, Literal, cast
from typing import Any, cast
import math, operator, struct, functools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
@ -139,65 +139,92 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
ret.append(u)
return functools.reduce(operator.add, ret) if changed else None
def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
# simplify x // y or x % y, None means no change
# simple cancel div/mod case
def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None:
# simple cancel div/mod case when the range of the numerator lies within a single denominator interval
x_min, x_max, y_min, y_max = x.vmin, x.vmax, y.vmin, y.vmax
assert isinstance(x_min, int) and isinstance(x_max, int) and isinstance(y_min, int) and isinstance(y_max, int)
if y_min==y_max==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
if y_min*y_max > 0 and (q:=cdiv(x_min,y_min)) == cdiv(x_min,y_max) == cdiv(x_max,y_min) == cdiv(x_max,y_max):
return x - q*y if which is Ops.MOD else x.const_like(q)
return x - q*y if d.op is Ops.MOD else d.const_like(q)
return None
if (y.op is not Ops.CONST) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
if y.arg == 0: raise ZeroDivisionError(f"{'Division' if which is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(which, y)}")
svars, factors, quotients, remainders, gcd, div, const, something_changed = [], [], [], [], c, 1, 0, False
def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
# remove nested mod in case the inner mod is a multiple of the outer mod
# example: (a%4 + b)%2 -> (a+b)%2
if ((c := y.arg) < 0) or x.vmin<0: return None
new_xs = []
something_changed = False
for u in split_uop(x, Ops.ADD):
if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
u = u.src[0]
something_changed = True
v: UOp = u.divides(f:=u.const_factor())
q, r = divmod(f, c)
if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
if u.op is Ops.CONST: const += f
else: # div is the smallest common divisor of all terms
if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
gcd = math.gcd(r, gcd)
factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702
if u.op is Ops.MOD:
if u.src[1].divides(c) is not None:
something_changed = True
u = u.src[0]
new_xs.append(u)
new_x: UOp = functools.reduce(operator.add, new_xs)
if something_changed and new_x.vmin>=0: return new_x % y
return None
def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we can fold if the expression has only one non-constant term and this term can only take on two values
if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
y1 = cmod(factors[0]*v.vmin+const, c) if which is Ops.MOD else cdiv(factors[0]*v.vmin+const, c)
y2 = cmod(factors[0]*v.vmax+const, c) if which is Ops.MOD else cdiv(factors[0]*v.vmax+const, c)
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
x,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
if len(terms)==1 and (v:=terms[0]).vmax-v.vmin == 1:
y1 = cmod(factors[0]*v.vmin+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmin+const, c) # type: ignore
y2 = cmod(factors[0]*v.vmax+const, c) if d.op is Ops.MOD else cdiv(factors[0]*v.vmax+const, c) # type: ignore
return (y2-y1)*(v-v.vmin) + y1
return None
if not CORRECT_DIVMOD_FOLDING or x_min>=0:
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
rems = [min(r, r-c, key=abs) for r in remainders]
if (rem:=sum(r*v for r,v in zip(rems,svars))+const%c).vmin//c==rem.vmax//c and all(f > 0 for f in factors):
if which is Ops.MOD: return rem - rem.vmin//c*c
return sum((f-r)//c * v for f,r,v in zip(factors,rems,svars)) + (const-const%c+rem.vmin//c*c)//c
# Rewrite rule for `x // y` / `x % y` with a constant `y`: `d` is the matched div/mod node (only d.op is read),
# `x` the numerator, `y` the denominator. Returns the rewritten UOp, or None when the rule does not apply.
def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
# bail out on: a possibly-negative numerator while CORRECT_DIVMOD_FOLDING is set, a negative denominator, or x.dtype.count > 1
if (x.vmin<0 and CORRECT_DIVMOD_FOLDING) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
# peel off the constant addend (const is 0 when x is not `something + CONST`)
x,const = x.pop_const()
# per addend u of x: f is its constant factor and u.divides(f) is the remaining term
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
# for each factor choose between f%c and f%c-c, whichever has the smaller absolute value (f%c wins ties, it is listed first)
rems = [min((r:=f%c), r-c, key=abs) for f in factors]
# fold only when every factor is positive and the whole value range of rem has one and the same floor-quotient by c
if (rem:=sum(r*v for r,v in zip(rems,terms))+const%c).vmin//c==rem.vmax//c and all(f > 0 for f in factors):
# mod: shift rem down by that shared quotient times c
if d.op is Ops.MOD: return rem - rem.vmin//c*c
# div: what was removed from each factor, (f-r), is divided by c, plus the quotient contributed by the constant and by rem
return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
return None
if (g:=math.gcd(gcd, const))!=1:
ret = UOp(which, x.dtype, src=(sum(f//g * v for f,v in zip(factors, svars)) + const//g, x.const_like(c//g)))
return ret*g if which is Ops.MOD else ret
# Rewrite rule for `x // y` / `x % y` with a constant `y`: `d` is the matched div/mod node (only d.op is read).
# Returns the rewritten UOp, or None when the gcd is 1 and nothing can be divided out.
def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
# x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
# per addend u of x: f is its constant factor and u.divides(f) is the remaining term
# NOTE(review): x is split without pop_const here, so a constant addend takes part via its own const_factor — presumably its value; confirm
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
# gcd of the denominator and every constant factor of the numerator
if (gcd := math.gcd(y.arg, *factors)) == 1: return None
# rebuild the numerator with every factor divided by gcd and apply the same op against y.arg//gcd
ret = sum(f//gcd * v for f,v in zip(factors, terms)).alu(d.op, y.const_like(y.arg//gcd))
# for mod the result is scaled back up by gcd, for div it is returned as is
return ret*gcd if d.op is Ops.MOD else ret
# Rewrite rule for `x // y` with a constant `y`. `d` is the matched node and is not read in the body.
# Returns `simplified(x//div) // (c//div)` when that inner division simplifies, otherwise None.
def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we try and nest the div and see if it allows the numerator to be simplified
# bail out on a negative denominator or x.dtype.count > 1
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
# constant factors of the addends of x, with the constant addend (if any) removed first
factors = [u.const_factor() for u in split_uop(x.pop_const()[0], Ops.ADD)]
# div is the smallest factor of the denominator (greater than 1) out of all "factors"
# TODO: there are better ways to pick `div`, this sometimes adds extra divisions
# TODO: add same optimization for mod
# candidates are |f| for factors with |f| > 1 and c%f == 0; y.arg is included so min() always has an element
div = min([y.arg]+[abs(f) for f in factors if abs(f) > 1 and (c%f)==0])
# proceed only if div is a proper divisor (1 < div < c), simplify() returned a different object than x//div,
# and both x and x//div have a non-negative vmin
if (1 < div < c) and (newxs:=(newx:=(x//div)).simplify()) is not newx and x.vmin>=0 and newx.vmin>=0: return newxs//(c//div)
return None
def simplify_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
# we try and take out the quotient and see if it allows the numerator to be simplified
if ((c := y.arg) < 0) or (x.dtype.count > 1): return None
x_no_const,const = x.pop_const()
terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x_no_const, Ops.ADD)])
quotients, remainders = zip(*[divmod(f, c) for f in factors])
gcd = math.gcd(c, *remainders) # gcd without const!
if const%c==const and gcd==1 and not any(r==0 or (r!=f and d.op is Ops.MOD) for r,f in zip(remainders, factors)): return None
if gcd != 1: something_changed = True
if not something_changed:
if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, x.const_like(div), Ops.IDIV)) is not None: return newx//(c//div)
return None
quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
for q,r,f,v in zip(quotients, remainders, factors, svars):
if which is Ops.IDIV and (not split_rem) and r!=0:
for q,r,f,v in zip(quotients, remainders, factors, terms):
if d.op is Ops.IDIV and r!=0:
rem += f//gcd * v
else:
rem += r//gcd * v
quo += q * v
# if numerator before/after is negative, and it has remainder, don't simplify because C divmod is different from python divmod.
if (x_min < 0 or rem.vmin < 0) and remainders: return None
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
if (x.vmin < 0 or rem.vmin < 0) and remainders: return None
if d.op is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
return rem//(c//gcd)+quo
def gep_through_wmma(gep:UOp, wmma:UOp):
@ -302,15 +329,20 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
# div folding
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd),
(UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
(UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder),
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
(UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
# ** mod **
# mod folding
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
(UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <=0 else None),
(UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
(UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
])+gep_pushing