tiny cleanup to transcendental xexp2 (#7326)

also added test for exp and log of nan and inf
This commit is contained in:
chenyu 2024-10-27 21:54:20 -04:00 committed by GitHub
commit cb5702f170
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 12 deletions

View file

@ -607,16 +607,20 @@ class TestOps(unittest.TestCase):
def test_log(self):
helper_test_op([(45,65)], torch.log, Tensor.log)
helper_test_op(None, torch.log, Tensor.log, vals=[[math.inf, -math.inf, math.nan]])
helper_test_op([()], torch.log, Tensor.log)
def test_log2(self):
helper_test_op([(45,65)], torch.log2, Tensor.log2)
helper_test_op(None, torch.log2, Tensor.log2, vals=[[math.inf, -math.inf, math.nan]])
helper_test_op([()], torch.log2, Tensor.log2)
def test_exp(self):
helper_test_op([(45,65)], torch.exp, Tensor.exp)
helper_test_op(None, torch.exp, Tensor.exp, vals=[[math.inf, -math.inf, math.nan]])
helper_test_op([()], torch.exp, Tensor.exp)
def test_exp2(self):
helper_test_op([(45,65)], torch.exp2, Tensor.exp2)
helper_test_op(None, torch.exp2, Tensor.exp2, vals=[[math.inf, -math.inf, math.nan]])
helper_test_op([()], torch.exp2, Tensor.exp2)
def test_sign(self):

View file

@ -230,29 +230,23 @@ def xexp2(d:UOp) -> UOp:
- Paper: https://arxiv.org/pdf/2001.09258
"""
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
fp64_p = d.dtype == dtypes.float64
# mask +=inf/nan as zero.
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
q = rintk(x)
# s = d - round(d)
s = x - q.cast(x.dtype)
# a polynomial approximation with 13 non-zero terms in the range of [(log 2)/2,(log 2)/2].
if fp64_p:
if d.dtype == dtypes.float64:
u = polyN(s, [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4,
0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
0.6931471805599452862e+0, 0.1000000000000000000e+1])
else:
u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
u = ldexp2k(u, q) # u*2^q
upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23}[x.dtype]
lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype]
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[x.dtype]
# Replace x >= upper with +inf
u = x.ne(upper).where(u, x.const_like(math.inf))
u = x.lt(upper).where(u, x.const_like(math.inf))
u = x.ge(upper).where(x.const_like(math.inf), u)
# Replace x <= lower with zero.
u = x.lt(lower).where(x.const_like(0.0), u)
# x=NaN never satisfies x < Inf. (for fastmode)
u = x.lt(math.inf).where(u, u.const_like(math.nan))
# exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
return _lazy_map_numbers(d, d.const_like(math.inf), d.const_like(0.0), d.const_like(math.nan), u)
@ -262,7 +256,6 @@ def xlog2(d:UOp) -> UOp:
Paper: https://arxiv.org/pdf/2001.09258
"""
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
fp64_p = d.dtype == dtypes.float64
FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
d_orig = d
denormal_map = d.lt(FLT_MIN)
@ -272,7 +265,7 @@ def xlog2(d:UOp) -> UOp:
m = ldexp3k(d, -e)
e = denormal_map.where(e + (-64), e)
if fp64_p:
if d.dtype == dtypes.float64:
x = (m - 1.0) * (m + 1.0).recip()
x2 = x * x
t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,