mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
e146418f65
commit
ac3449b0c8
2 changed files with 7 additions and 2 deletions
|
|
@ -102,13 +102,18 @@ class TestHelpers(unittest.TestCase):
|
|||
self.assertEqual(truncate_fp16(65504), 65504)
|
||||
self.assertEqual(truncate_fp16(65519.999), 65504)
|
||||
self.assertEqual(truncate_fp16(65520), math.inf)
|
||||
self.assertEqual(truncate_fp16(1e-8), 0.0)
|
||||
self.assertEqual(truncate_fp16(-65504), -65504)
|
||||
self.assertEqual(truncate_fp16(-65519.999), -65504)
|
||||
self.assertEqual(truncate_fp16(-65520), -math.inf)
|
||||
self.assertTrue(math.isnan(truncate_fp16(math.nan)))
|
||||
|
||||
def test_truncate_bf16(self):
|
||||
self.assertEqual(truncate_bf16(1), 1)
|
||||
# TODO: rounding, torch bfloat 1.1 gives 1.1015625 instead of 1.09375
|
||||
self.assertAlmostEqual(truncate_bf16(1.1), 1.09375, places=7)
|
||||
for a in [1234, 23456, -777.777]:
|
||||
self.assertEqual(truncate_bf16(a), torch.tensor([a], dtype=torch.bfloat16).item())
|
||||
# TODO: torch bfloat 1.1 gives 1.1015625 instead of 1.09375
|
||||
max_bf16 = torch.finfo(torch.bfloat16).max
|
||||
self.assertEqual(truncate_bf16(max_bf16), max_bf16)
|
||||
self.assertEqual(truncate_bf16(min_bf16:=-max_bf16), min_bf16)
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ def sum_acc_dtype(dt:DType):
|
|||
return least_upper_dtype(dt, to_dtype(getenv("SUM_DTYPE", "float32")))
|
||||
|
||||
def truncate_fp16(x):
|
||||
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
|
||||
try: return struct.unpack('e', struct.pack('e', float(x)))[0]
|
||||
except OverflowError: return math.copysign(math.inf, x)
|
||||
|
||||
def truncate_bf16(x):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue