mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
7a9e3247c2
commit
89eebd4bfb
1 changed files with 4 additions and 8 deletions
|
|
@ -3315,14 +3315,10 @@ class Tensor(SimpleMathTrait):
|
|||
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
|
||||
# start with b ** e = exp(e * log(b))
|
||||
ret = base.abs().log().mul(exponent).exp()
|
||||
# correct sign of negative base with odd exponent
|
||||
negative_base = (base < 0).detach().where(1, 0)
|
||||
# 1 for non-negative base or negative even exponent, -1 for negative odd exponent, don't care about non-integer exponent
|
||||
correct_sign = (exponent.int()%2==0).where(1, 1-2*negative_base)
|
||||
# inject nan for negative base and non-integer exponent
|
||||
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
|
||||
# apply correct_sign inject_nan, and fix 0 ** 0 = 1
|
||||
ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
|
||||
# negative base adjustment: nan for non-integer exponent and -1 for odd exponent
|
||||
adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
|
||||
# fix 0 ** 0 = 1
|
||||
ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
|
||||
return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
|
||||
|
||||
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue