minor fast_idiv cleanup [pr] (#13109)

This commit is contained in:
chenyu 2025-11-05 11:44:36 -05:00 committed by GitHub
commit 03ee0cfe45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -282,7 +282,7 @@ def magicgu(vmax:int, d:int) -> tuple[int,int]:
def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
# If d is a power of two this is not valid for signed ints!
is_unsigned = True if x.vmin>=0 or x.dtype in dtypes.uints else False
is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
assert d>0, "Sign should have been taken out of divisor"
vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
m,s = magicgu(max(vmax, abs(vmin)), d)
@ -293,7 +293,7 @@ def fast_idiv(device: str, x: UOp, d: int, dont_cast=False) -> UOp|None:
if (ret:=fast_idiv(device, x//largest_factor_of_two_in_d, d//largest_factor_of_two_in_d, dont_cast=True)) is not None: return ret
if dont_cast: return None
# promo_lattice needs to return an unsigned type if the type is unsigned
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, None if device=='' else device):
if dtypes.is_int(next_dtype := promo_lattice[x.dtype.scalar()][-1]) and is_dtype_supported(next_dtype, device):
if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
return None