mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cast parents of int64 alu to int32 if possible (#11977)
* add overflows helper * add rules * x -> y * check overflow of u too * cleaner * use alu instead of replace to preserve vectorization * just one rule * add test
This commit is contained in:
parent
033184b3cb
commit
86e908db57
3 changed files with 9 additions and 0 deletions
|
|
@ -732,6 +732,12 @@ class TestSymbolic(unittest.TestCase):
|
|||
a = Variable("a", 1, 10, dtypes.int)
|
||||
self.helper_test_variable(a.trunc(), 1, 10, "a", test_z3=False)
|
||||
|
||||
def test_do_math_in_int32(self):
|
||||
a = Variable("a", 1, 10)
|
||||
b = Variable("b", 1, 10)
|
||||
self.helper_test_variable(a.cast(dtypes.long)+b.cast(dtypes.long), 2, 20, "(long)((a+b))")
|
||||
self.helper_test_variable(a.cast(dtypes.long)*b.cast(dtypes.long), 1, 100, "(long)((a*b))")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
|
|
|||
|
|
@ -327,6 +327,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
||||
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
|
||||
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
|
||||
def overflows(self, dtype:DType) -> bool: return self.vmin < dtype.min or dtype.max < self.vmax
|
||||
|
||||
# *** from MultiLazyBuffer ***
|
||||
|
||||
|
|
|
|||
|
|
@ -289,6 +289,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
|
||||
(UPat.var('x', dtypes.ints).cast(dtypes.ints, name="a").cast(name="b"),
|
||||
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
|
||||
(UPat(GroupOp.Binary, src=(UPat.var("x",dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
|
||||
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue