cast parents of int64 alu to int32 if possible (#11977)

* add overflows helper

* add rules

* x -> y

* check overflow of u too

* cleaner

* use alu instead of replace to preserve vectorization

* just one rule

* add test
This commit is contained in:
Sieds Lykles 2025-09-03 11:05:04 +02:00 committed by GitHub
commit 86e908db57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 9 additions and 0 deletions

View file

@@ -732,6 +732,12 @@ class TestSymbolic(unittest.TestCase):
a = Variable("a", 1, 10, dtypes.int)
self.helper_test_variable(a.trunc(), 1, 10, "a", test_z3=False)
def test_do_math_in_int32(self):
  # both variables are bounded to [1, 10]; adding or multiplying their long casts is
  # expected to render as the op on a and b wrapped in a single (long) cast
  a = Variable("a", 1, 10)
  b = Variable("b", 1, 10)
  a_long, b_long = a.cast(dtypes.long), b.cast(dtypes.long)
  self.helper_test_variable(a_long+b_long, 2, 20, "(long)((a+b))")
  self.helper_test_variable(a_long*b_long, 1, 100, "(long)((a*b))")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10

View file

@@ -327,6 +327,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
  # build an Ops.ALLREDUCE UOp with this UOp's dtype, sources (self, device) and `op` as the arg
  # the receiver's own device has to be a tuple
  assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
  # a device that is already a UOp is used as-is, anything else is wrapped in an Ops.DEVICE UOp
  device_src = device if isinstance(device, UOp) else UOp(Ops.DEVICE, arg=device)
  return UOp(Ops.ALLREDUCE, self.dtype, (self, device_src), op)
def overflows(self, dtype:DType) -> bool:
  # True when vmin lies below dtype.min or vmax lies above dtype.max
  below_min = self.vmin < dtype.min
  return below_min or dtype.max < self.vmax
# *** from MultiLazyBuffer ***

View file

@@ -289,6 +289,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
(UPat.var('x', dtypes.ints).cast(dtypes.ints, name="a").cast(name="b"),
lambda x,a,b: x.cast(b.dtype) if a.dtype.min<=x.vmin and x.vmax<=a.dtype.max else None),
(UPat(GroupOp.Binary, src=(UPat.var("x",dtypes.long), UPat.var("y", dtypes.long)), name="u"), lambda u,x,y:
x.cast(dtypes.int).alu(u.op, y.cast(dtypes.int)).cast(u.dtype) if not any(v.overflows(dtypes.int) for v in (u,x,y)) else None),
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),