mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into no_merge_views
This commit is contained in:
commit
4fd4e13fcf
4 changed files with 115 additions and 57 deletions
|
|
@ -143,13 +143,12 @@ class TestIndexingConstFolding(unittest.TestCase):
|
|||
_check_ast_count(1, t[:,:,Tensor(1)+2,:])
|
||||
_check_ast_count(1, t[:,:,Tensor(1),Tensor(0)])
|
||||
|
||||
def test_const_tensor_index(self):
  # Index a realized 1x1x4x4 float tensor with constant int tensors and check the AST count of each lookup.
  # TODO: these can be 0, implement const tensor folded indexing
  t = Tensor.arange(16).float().reshape(1,1,4,4).realize()

  # a single (2,1) index tensor on the third axis
  picked = t[:,:,Tensor.ones(2,1,dtype=dtypes.int),:]
  _check_ast_count(1, picked)

  # a (1,2) index tensor with a constant offset added
  picked = t[:,:,Tensor.ones(1,2,dtype=dtypes.int)+2,:]
  _check_ast_count(1, picked)

  # two index tensors at once, on the third and fourth axes
  picked = t[:,:,Tensor.ones(1,1,dtype=dtypes.int),Tensor.zeros(2,1,2,dtype=dtypes.int)]
  _check_ast_count(1, picked)
|
||||
|
||||
class TestMovedConstFolding(unittest.TestCase):
|
||||
def test_add_shrunk_zero(self):
|
||||
|
|
|
|||
|
|
@ -30,16 +30,16 @@ class TestSymbolicPickle(unittest.TestCase):
|
|||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s, test_z3:bool=True):
|
||||
rendered, nmin, nmax = render(v)
|
||||
if isinstance(s, tuple): self.assertIn(rendered, s)
|
||||
else: self.assertEqual(rendered, s)
|
||||
self.assertEqual(nmin, n)
|
||||
self.assertEqual(nmax, m)
|
||||
if test_z3:
|
||||
solver = z3.Solver()
|
||||
z3_sink = graph_rewrite(v.sink(v.simplify()), z3_renderer, ctx=(solver, {}))
|
||||
expr, epxr_simplified = z3_sink.src[0].arg, z3_sink.src[1].arg
|
||||
self.assertEqual(solver.check(expr != epxr_simplified), z3.unsat, "simplified expression not equal to original")
|
||||
rendered, nmin, nmax = render(v)
|
||||
if isinstance(s, tuple): self.assertIn(rendered, s)
|
||||
else: self.assertEqual(rendered, s)
|
||||
self.assertEqual(nmin, n)
|
||||
self.assertEqual(nmax, m)
|
||||
|
||||
def test_cmp_simple(self):
|
||||
self.helper_test_variable(Variable("a", 3, 8) < 4, 0, 1, "(a<4)")
|
||||
|
|
@ -266,6 +266,16 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(((5*Variable("a", 0, 31)) % 12) % 5, 0, 4, "(((a*5)%12)%5)")
|
||||
self.helper_test_variable((Variable("a", 0, 31) % 4) % 12, 0, 3, "(a%4)")
|
||||
|
||||
def test_mod_mod_wrong_sign(self):
  # (v3+-1) can be -1, so the inner %5 may be negative; the expected render is the unchanged expression
  v1 = Variable("v1", 0, 128)
  v3 = Variable("v3", 0, 7)
  inner = ((v1%2)*2)+((v3+-1)%5)
  self.helper_test_variable((inner+-2)%5, -4, 4, "(((((v1%2)*2)+((v3+-1)%5))+-2)%5)")
|
||||
|
||||
def test_mod_mod_wrong_sign2(self):
  # (v2+-2) can be -2, so the numerator may be negative; only the duplicated outer %7 is expected to go away
  v2 = Variable("v2", 0, 8)
  v3 = Variable("v3", 0, 4)
  once = (((v3+3)%7)+(v2+-2))%7
  self.helper_test_variable(once%7, -6, 6, "(((v2+((v3+3)%7))+-2)%7)")
|
||||
|
||||
def test_mul_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
|
||||
|
||||
|
|
@ -375,6 +385,17 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_mul_div(self):
  # multiplying by 4 and then floor-dividing by 4 should render as the bare variable
  a = Variable("a", 0, 10)
  self.helper_test_variable((a*4)//4, 0, 10, "a")
|
||||
|
||||
def test_div_drop_small_terms(self):
  # from openpilot, shouldn't simplify: the expected render is the input expression itself
  gidx0, gidx1 = UOp.variable("gidx0", 0, 10), UOp.variable("gidx1", 0, 10)
  lidx0, lidx1 = UOp.variable("lidx0", 0, 1), UOp.variable("lidx1", 0, 1)
  ridx1005, ridx1006 = UOp.variable("ridx1005", 0, 2), UOp.variable("ridx1006", 0, 2)
  # terms whose factor is a multiple of 18, grouped exactly as in the expected string
  multiples_of_18 = (gidx1*18)+(ridx1005*18)+(lidx0*162)
  numerator = lidx1+multiples_of_18+(gidx0*2)+(ridx1006*2)+-40
  self.helper_test_variable(numerator//18, -2, 20,
                            "(((((lidx1+(((gidx1*18)+(ridx1005*18))+(lidx0*162)))+(gidx0*2))+(ridx1006*2))+-40)//18)")
|
||||
|
||||
def test_add_div(self):
|
||||
# careful about the lower bounds and upper bounds
|
||||
self.helper_test_variable((Variable("a", 0, 5)-2)//4, 0, 0, "0")
|
||||
|
|
@ -421,6 +442,11 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_div_numerator_negative(self):
  # a negative factor in the numerator is expected to be pulled out of the division as a trailing *-1
  idx = Variable("idx", 0, 9)
  self.helper_test_variable((idx*-10)//11, -8, 0, "(((idx*10)//11)*-1)")
|
||||
|
||||
def test_nest_div_negative_factor(self):
  # numerator with a negative factor (-7) that divides the denominator 35
  ridx0 = UOp.variable("ridx0", 0, 9)
  ridx1 = UOp.variable("ridx1", 0, 6)
  numerator = ((ridx0*-7)+ridx1)+63
  self.helper_test_variable(numerator//35, 0, 1, "(((ridx0//5)*-1)+1)")
|
||||
|
||||
def test_div_into_mod(self):
|
||||
self.helper_test_variable((Variable("idx", 0, 16)*4)%8//4, 0, 1, "(idx%2)")
|
||||
|
||||
|
|
|
|||
|
|
@ -487,6 +487,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
def pop_const(self) -> tuple[UOp, int]:
  # split `expr + CONST` into (expr, value of the constant); any other node comes back unchanged with 0
  if self.op is Ops.ADD and self.src[1].op is Ops.CONST: return self.src[0], self.src[1].arg
  return self, 0
|
||||
@property
def vmin(self) -> ConstType:
  # first element of the `_min_max` pair
  bounds = self._min_max
  return bounds[0]
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# all of symbolic lives here now
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, cast
|
||||
import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
|
|
@ -139,65 +139,92 @@ def canonicalize_simplex(X:UOp) -> UOp|None:
|
|||
ret.append(u)
|
||||
return functools.reduce(operator.add, ret) if changed else None
|
||||
|
||||
def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
|
||||
# simplify x // y or x % y, None means no change
|
||||
# simple cancel div/mod case
|
||||
def cancel_divmod(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # simple cancel div/mod case when the range of the numerator lies within a single denominator interval
  # d is the IDIV/MOD node being rewritten, x its numerator and y its denominator; None means no change
  xlo, xhi, ylo, yhi = x.vmin, x.vmax, y.vmin, y.vmax
  assert isinstance(xlo, int) and isinstance(xhi, int) and isinstance(ylo, int) and isinstance(yhi, int)

  if ylo==yhi==0: raise ZeroDivisionError(f"{'Division' if d.op is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(d.op, y)}")
  # the denominator range must not contain or touch zero
  if not (ylo*yhi > 0): return None
  # the quotient has to be the same at all four corners of the (x, y) ranges
  q = cdiv(xlo,ylo)
  if not (q == cdiv(xlo,yhi) == cdiv(xhi,ylo) == cdiv(xhi,yhi)): return None
  if d.op is Ops.MOD: return x - q*y
  return d.const_like(q)
|
||||
|
||||
if (y.op is not Ops.CONST) or ((c := y.arg) < 0) or (x.dtype.count > 1): return None
|
||||
if y.arg == 0: raise ZeroDivisionError(f"{'Division' if which is Ops.IDIV else 'Mod'} by zero trying to rewrite {x.alu(which, y)}")
|
||||
|
||||
svars, factors, quotients, remainders, gcd, div, const, something_changed = [], [], [], [], c, 1, 0, False
|
||||
def remove_nested_mod(m: UOp, x: UOp, y: UOp) -> UOp|None:
  # remove nested mod in case the inner mod is a multiple of the outer mod
  # example: (a%4 + b)%2 -> (a+b)%2
  # m is the outer MOD node (unused here), x its numerator, y the constant modulus; None means no change
  c = y.arg
  if c < 0 or x.vmin<0: return None
  stripped, changed = [], False
  for u in split_uop(x, Ops.ADD):
    # an addend `t % k` is replaced by `t` when `k.divides(c)` succeeds
    if u.op is Ops.MOD and u.src[1].divides(c) is not None:
      u, changed = u.src[0], True
    stripped.append(u)
  new_x: UOp = functools.reduce(operator.add, stripped)
  # only rewrite when something was stripped and the new numerator still cannot go negative
  return new_x % y if changed and new_x.vmin>=0 else None
|
||||
|
||||
def fold_binary_numerator(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # we can fold if the expression has only one non-constant term and this term can only take on two values
  # d is the IDIV/MOD node, x its numerator, y the constant denominator; None means no change
  c = y.arg
  if c < 0 or x.dtype.count > 1: return None
  x,const = x.pop_const()
  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
  if len(terms) != 1: return None
  v, fac = terms[0], factors[0]
  if v.vmax-v.vmin != 1: return None
  # evaluate the div/mod at both values of v, then interpolate linearly between the two results
  fold = cmod if d.op is Ops.MOD else cdiv
  y1 = fold(fac*v.vmin+const, c) # type: ignore
  y2 = fold(fac*v.vmax+const, c) # type: ignore
  return (y2-y1)*(v-v.vmin) + y1
|
||||
|
||||
if not CORRECT_DIVMOD_FOLDING or x_min>=0:
|
||||
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
|
||||
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
|
||||
rems = [min(r, r-c, key=abs) for r in remainders]
|
||||
if (rem:=sum(r*v for r,v in zip(rems,svars))+const%c).vmin//c==rem.vmax//c and all(f > 0 for f in factors):
|
||||
if which is Ops.MOD: return rem - rem.vmin//c*c
|
||||
return sum((f-r)//c * v for f,r,v in zip(factors,rems,svars)) + (const-const%c+rem.vmin//c*c)//c
|
||||
def fold_divmod_congruence(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
  # d is the IDIV/MOD node, x its numerator, y the constant denominator; None means no change
  if x.vmin<0 and CORRECT_DIVMOD_FOLDING: return None
  c = y.arg
  if c < 0 or x.dtype.count > 1: return None
  x,const = x.pop_const()
  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
  # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
  # for each factor take whichever of f%c and f%c-c has the smaller magnitude (a tie keeps f%c)
  rems = [min(r, r-c, key=abs) for r in (f%c for f in factors)]
  rem = sum(r*v for r,v in zip(rems,terms))+const%c
  # the reduced expression must stay within a single multiple of c, and every factor must be positive
  if rem.vmin//c != rem.vmax//c or not all(f > 0 for f in factors): return None
  if d.op is Ops.MOD: return rem - rem.vmin//c*c
  return sum((f-r)//c * v for f,r,v in zip(factors,rems,terms)) + (const-const%c+rem.vmin//c*c)//c
|
||||
|
||||
if (g:=math.gcd(gcd, const))!=1:
|
||||
ret = UOp(which, x.dtype, src=(sum(f//g * v for f,v in zip(factors, svars)) + const//g, x.const_like(c//g)))
|
||||
return ret*g if which is Ops.MOD else ret
|
||||
def divide_by_gcd(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # x//y -> (x//gcd)//(y//gcd) or x%y -> gcd*(x//gcd)%(y//gcd)
  # d is the IDIV/MOD node, x its numerator, y the constant denominator; None means no change
  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x, Ops.ADD)])
  g = math.gcd(y.arg, *factors)
  if g == 1: return None
  # divide every addend's constant factor and the denominator by the common divisor
  reduced_x = sum(f//g * v for f,v in zip(factors, terms))
  ret = reduced_x.alu(d.op, y.const_like(y.arg//g))
  if d.op is Ops.MOD: return ret*g
  return ret
|
||||
|
||||
def nest_div_by_smallest_factor(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # we try and nest the div and see if it allows the numerator to be simplified
  # d is the IDIV node (unused here), x its numerator, y the constant denominator; None means no change
  c = y.arg
  if c < 0 or x.dtype.count > 1: return None
  factors = [u.const_factor() for u in split_uop(x.pop_const()[0], Ops.ADD)]
  # div is the smallest factor of the denominator (greater than 1) out of all "factors"
  # TODO: there are better ways to pick `div`, this sometimes adds extra divisions
  # TODO: add same optimization for mod
  candidates = [abs(f) for f in factors if abs(f) > 1 and (c%f)==0]
  div = min([c]+candidates)
  if not (1 < div < c): return None
  # x//div counts as simplified only if simplify() returned a different object
  newx = x//div
  newxs = newx.simplify()
  if newxs is not newx and x.vmin>=0 and newx.vmin>=0: return newxs//(c//div)
  return None
|
||||
|
||||
def simplify_remainder(d: UOp, x: UOp, y: UOp) -> UOp|None:
  # we try and take out the quotient and see if it allows the numerator to be simplified
  # d is the IDIV/MOD node, x its numerator, y the constant denominator; None means no change
  c = y.arg
  if c < 0 or x.dtype.count > 1: return None
  x_no_const,const = x.pop_const()
  terms, factors = zip(*[(u.divides(f:=u.const_factor()),f) for u in split_uop(x_no_const, Ops.ADD)])
  quotients, remainders = zip(*[divmod(f, c) for f in factors])
  gcd = math.gcd(c, *remainders) # gcd without const!
  # bail out when the constant is already below c, there is no common divisor, and no factor would change
  if const%c==const and gcd==1 and not any(r==0 or (r!=f and d.op is Ops.MOD) for r,f in zip(remainders, factors)): return None

  quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
  is_div = d.op is Ops.IDIV
  for q,r,f,v in zip(quotients, remainders, factors, terms):
    if is_div and r!=0:
      # for a div, a term with a nonzero remainder stays whole in the numerator
      rem += f//gcd * v
    else:
      # otherwise split the term: the remainder stays in the numerator, the quotient part is added outside
      rem += r//gcd * v
      quo += q * v

  # if numerator before/after is negative, and it has remainder, don't simplify because C divmod is different from python divmod.
  if (x.vmin < 0 or rem.vmin < 0) and remainders: return None
  if d.op is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
  return rem//(c//gcd)+quo
|
||||
|
||||
def gep_through_wmma(gep:UOp, wmma:UOp):
|
||||
|
|
@ -302,15 +329,20 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
# div folding
|
||||
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
|
||||
if c.vmin>0 and d.vmin>0 and ((x.vmin>=0 and a.vmin>=0) or (x.vmax<=0 and a.vmax<=0)) else None), # (x//c+a)//d -> (x+a*c)//(c*d)
|
||||
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
|
||||
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.var("y"))), cancel_divmod),
|
||||
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_binary_numerator),
|
||||
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), fold_divmod_congruence),
|
||||
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), divide_by_gcd),
|
||||
(UPat(Ops.MOD, dtypes.sints, name="m", src=(UPat.var("x"), UPat.cvar("y", vec=False))), remove_nested_mod),
|
||||
(UPat((Ops.IDIV), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), nest_div_by_smallest_factor),
|
||||
(UPat((Ops.IDIV, Ops.MOD), dtypes.sints, name="d", src=(UPat.var("x"), UPat.cvar("y", vec=False))), simplify_remainder),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax < 0 else None),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
|
||||
((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
|
||||
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
|
||||
(UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <=0 else None),
|
||||
(UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
|
||||
(UPat.var("x") % UPat.var("d"), lambda x,d: (x%(-d)) if d.vmax < 0 else None),
|
||||
])+gep_pushing
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue