start cleaning up dtype tests (#16324)

This commit is contained in:
Christopher Milan 2026-05-21 18:11:49 -07:00 committed by GitHub
commit 150a82de1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 50 additions and 95 deletions

View file

@ -103,31 +103,6 @@ class TestDType(unittest.TestCase):
_test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
_test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)
def test_dtypes_DTYPES_DICT(self):
self.assertIn("float", DTYPES_DICT)
self.assertIn("float32", DTYPES_DICT)
self.assertEqual(len(DTYPES_DICT), 28)
self.assertTrue(all(isinstance(value, DType) for value in DTYPES_DICT.values()))
self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in DTYPES_DICT.values() if _to_np_dtype(value) is not None))
def test_resulting_and_init_dtypes_match(self):
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
data = [1., 2., 0., 0.5, -1.5, 5.25]
for dt in dtypes:
arr = np.asarray(data).astype(dt)
tensor = Tensor(arr)
if tensor.dtype not in supported_dtypes: continue
tin = tensor.numpy()
tor = torch.as_tensor(arr).detach().numpy()
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
def test_finfo(self):
if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return
info = np.finfo(_to_np_dtype(self.DTYPE))
self.assertEqual(info.bits, self.DTYPE.bitsize)
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE))
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
@ -137,12 +112,6 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
class TestFp8s(unittest.TestCase):
def test_fp8e4m3_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3).dtype == dtypes.fp8e4m3
def test_fp8e5m2_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2
def test_fp8e4m3fnuz_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3fnuz).dtype == dtypes.fp8e4m3fnuz
def test_fp8e5m2fnuz_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2fnuz).dtype == dtypes.fp8e5m2fnuz
class TestFp8sConversions(unittest.TestCase):
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
def test_float_to_fp8e4m3(self, x):
@ -192,25 +161,6 @@ class TestFp8sConversions(unittest.TestCase):
def test_fp8e5m2fnuz_to_float(self, x):
np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2fnuz), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2fnuz).float().item())
class TestBFloat16(unittest.TestCase):
def test_bf16_creation_numpy(self):
data = [-1, 1, 2]
t = Tensor(data, dtype=dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16
tnp = t.numpy()
assert tnp.dtype == np.float32
np.testing.assert_allclose(tnp, np.array(data))
def test_bf16_ones(self):
t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16
np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
def test_bf16_eye(self):
t = Tensor.eye(3, dtype=dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16
np.testing.assert_allclose(t.numpy(), np.eye(3))
class TestBFloat16DType(unittest.TestCase):
def test_bf16_to_float(self):
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
@ -419,48 +369,6 @@ class TestEmulatedFp8e5m2(TestFp8e5m2):
@classmethod
def tearDownClass(cls): cls.stack.close()
class TestPtrDType(unittest.TestCase):
def test_vec_double(self):
dt1 = dtypes.float.vec(4).ptr().vec(4)
dt2 = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt1, dt2)
self.assertEqual(str(dt1), str(dt2))
def test_scalar(self):
dt = dtypes.float.vec(4).ptr().scalar()
self.assertEqual(dt.base, dtypes.float.vec(4))
dt = dtypes.float.vec(4).ptr().vec(4).scalar()
self.assertEqual(dt.base, dtypes.float.vec(4))
dt = dtypes.float.vec(4).scalar()
self.assertEqual(dt, dtypes.float)
def test_serialize(self):
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt, eval(str(dt)))
def test_vec_ptr_sz(self):
dt = dtypes.float.ptr(1024).vec(4)
self.assertEqual(dt, eval(str(dt)))
self.assertEqual(str(dt), "dtypes.float.ptr(1024).vec(4)")
def test_vcount(self):
dt = dtypes.float.ptr().vec(4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 1)
dt = dtypes.float.vec(4).ptr()
self.assertEqual(dt.vcount, 1)
self.assertEqual(dt.v, 1)
self.assertEqual(dt.count, 4)
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 4)
class TestImplicitFunctionTypeChange(unittest.TestCase):
def test_functions(self):
result = []

View file

@ -139,12 +139,12 @@ class TestDTypeALU(unittest.TestCase):
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
def test_bfloat16(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(b, dtypes.bfloat16), dtypes.bfloat16, op)
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
@Context(EMULATED_DTYPES="bfloat16")
def test_emulated_bfloat16(self, a, b, op):
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(b, dtypes.bfloat16), dtypes.bfloat16, op)
@unittest.skipUnless(dtypes.fp8e4m3 in supported_dtypes, f"no fp8e4m3 on {Device.DEFAULT}")
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))

View file

@ -10,6 +10,48 @@ class TestImageDType(unittest.TestCase):
assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4)
assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4)
class TestPtrDType(unittest.TestCase):
def test_vec_double(self):
dt1 = dtypes.float.vec(4).ptr().vec(4)
dt2 = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt1, dt2)
self.assertEqual(str(dt1), str(dt2))
def test_scalar(self):
dt = dtypes.float.vec(4).ptr().scalar()
self.assertEqual(dt.base, dtypes.float.vec(4))
dt = dtypes.float.vec(4).ptr().vec(4).scalar()
self.assertEqual(dt.base, dtypes.float.vec(4))
dt = dtypes.float.vec(4).scalar()
self.assertEqual(dt, dtypes.float)
def test_serialize(self):
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt, eval(str(dt)))
def test_vec_ptr_sz(self):
dt = dtypes.float.ptr(1024).vec(4)
self.assertEqual(dt, eval(str(dt)))
self.assertEqual(str(dt), "dtypes.float.ptr(1024).vec(4)")
def test_vcount(self):
dt = dtypes.float.ptr().vec(4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 1)
dt = dtypes.float.vec(4).ptr()
self.assertEqual(dt.vcount, 1)
self.assertEqual(dt.v, 1)
self.assertEqual(dt.count, 4)
dt = dtypes.float.vec(4).ptr().vec(4)
self.assertEqual(dt.vcount, 4)
self.assertEqual(dt.v, 4)
self.assertEqual(dt.count, 4)
class TestEqStrDType(unittest.TestCase):
def test_image_ne(self):
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")

View file

@ -178,6 +178,12 @@ class TestHelpers(unittest.TestCase):
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item())
def test_finfo(self):
for dt in [dtypes.float16, dtypes.float32, dtypes.float64]:
info = np.finfo(_to_np_dtype(dt))
self.assertEqual(info.bits, dt.bitsize)
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(dt))
class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_self_promo_to_self(self, dtype):

View file

@ -59,7 +59,6 @@ class TestTypeSpec(unittest.TestCase):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
@unittest.skipUnless(dtypes.int8 in supported_dtypes, f"no int8 on {Device.DEFAULT}")
def test_dtype_str_arg(self):
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
tested = 0