truncate_fp16 cleanup (#11838)

native `@` is default
This commit is contained in:
chenyu 2025-08-25 19:03:41 -04:00 committed by GitHub
commit ac3449b0c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 2 deletions

View file

@ -102,13 +102,18 @@ class TestHelpers(unittest.TestCase):
self.assertEqual(truncate_fp16(65504), 65504)
self.assertEqual(truncate_fp16(65519.999), 65504)
self.assertEqual(truncate_fp16(65520), math.inf)
self.assertEqual(truncate_fp16(1e-8), 0.0)
self.assertEqual(truncate_fp16(-65504), -65504)
self.assertEqual(truncate_fp16(-65519.999), -65504)
self.assertEqual(truncate_fp16(-65520), -math.inf)
self.assertTrue(math.isnan(truncate_fp16(math.nan)))
def test_truncate_bf16(self):
self.assertEqual(truncate_bf16(1), 1)
# TODO: rounding, torch bfloat 1.1 gives 1.1015625 instead of 1.09375
self.assertAlmostEqual(truncate_bf16(1.1), 1.09375, places=7)
for a in [1234, 23456, -777.777]:
self.assertEqual(truncate_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
# TODO: torch bfloat 1.1 gives 1.1015625 instead of 1.09375
max_bf16 = torch.finfo(torch.bfloat16).max
self.assertEqual(truncate_bf16(max_bf16), max_bf16)
self.assertEqual(truncate_bf16(min_bf16:=-max_bf16), min_bf16)

View file

@ -215,7 +215,7 @@ def sum_acc_dtype(dt:DType):
return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
def truncate_fp16(x):
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
try: return struct.unpack('e', struct.pack('e', float(x)))[0]
except OverflowError: return math.copysign(math.inf, x)
def truncate_bf16(x):