pow cleanups (#8894)

more readable
This commit is contained in:
chenyu 2025-02-04 15:52:57 -05:00 committed by GitHub
commit 89eebd4bfb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3315,14 +3315,10 @@ class Tensor(SimpleMathTrait):
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
# start with b ** e = exp(e * log(b))
ret = base.abs().log().mul(exponent).exp()
# correct sign of negative base with odd exponent
negative_base = (base < 0).detach().where(1, 0)
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base)
# inject nan for negative base and non-integer exponent
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
# apply correct_sign inject_nan, and fix 0 ** 0 = 1
ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
# negative base adjustment: nan for non-integer exponent and -1 for odd exponent
adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
# fix 0 ** 0 = 1
ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor: