mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
start cleaning up dtype tests (#16324)
This commit is contained in:
parent
31424cda71
commit
150a82de1f
5 changed files with 50 additions and 95 deletions
|
|
@ -103,31 +103,6 @@ class TestDType(unittest.TestCase):
|
|||
_test_to_np(Tensor(v, dtype=self.DTYPE)+2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))+2)
|
||||
_test_to_np(Tensor(v, dtype=self.DTYPE)*2, _to_np_dtype(self.DTYPE), np.array(v, dtype=_to_np_dtype(self.DTYPE))*2)
|
||||
|
||||
def test_dtypes_DTYPES_DICT(self):
|
||||
self.assertIn("float", DTYPES_DICT)
|
||||
self.assertIn("float32", DTYPES_DICT)
|
||||
self.assertEqual(len(DTYPES_DICT), 28)
|
||||
self.assertTrue(all(isinstance(value, DType) for value in DTYPES_DICT.values()))
|
||||
self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in DTYPES_DICT.values() if _to_np_dtype(value) is not None))
|
||||
|
||||
def test_resulting_and_init_dtypes_match(self):
|
||||
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
|
||||
data = [1., 2., 0., 0.5, -1.5, 5.25]
|
||||
for dt in dtypes:
|
||||
arr = np.asarray(data).astype(dt)
|
||||
tensor = Tensor(arr)
|
||||
if tensor.dtype not in supported_dtypes: continue
|
||||
tin = tensor.numpy()
|
||||
tor = torch.as_tensor(arr).detach().numpy()
|
||||
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
|
||||
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
|
||||
|
||||
def test_finfo(self):
|
||||
if self.DTYPE not in [dtypes.float16, dtypes.float32, dtypes.float64]: return
|
||||
info = np.finfo(_to_np_dtype(self.DTYPE))
|
||||
self.assertEqual(info.bits, self.DTYPE.bitsize)
|
||||
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(self.DTYPE))
|
||||
|
||||
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
|
||||
target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
|
||||
if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
|
||||
|
|
@ -137,12 +112,6 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
|
|||
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
|
||||
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
|
||||
|
||||
class TestFp8s(unittest.TestCase):
|
||||
def test_fp8e4m3_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3).dtype == dtypes.fp8e4m3
|
||||
def test_fp8e5m2_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2).dtype == dtypes.fp8e5m2
|
||||
def test_fp8e4m3fnuz_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e4m3fnuz).dtype == dtypes.fp8e4m3fnuz
|
||||
def test_fp8e5m2fnuz_creation(self): assert Tensor([-1, 1, 2], dtype=dtypes.fp8e5m2fnuz).dtype == dtypes.fp8e5m2fnuz
|
||||
|
||||
class TestFp8sConversions(unittest.TestCase):
|
||||
@given(strat.floats(width=32, allow_subnormal=True, allow_nan=False, allow_infinity=False, min_value=-FP8E4M3_MAX, max_value=FP8E4M3_MAX))
|
||||
def test_float_to_fp8e4m3(self, x):
|
||||
|
|
@ -192,25 +161,6 @@ class TestFp8sConversions(unittest.TestCase):
|
|||
def test_fp8e5m2fnuz_to_float(self, x):
|
||||
np.testing.assert_equal(fp8_to_float(x, dtypes.fp8e5m2fnuz), torch.tensor(x, dtype=torch.uint8).view(torch.float8_e5m2fnuz).float().item())
|
||||
|
||||
class TestBFloat16(unittest.TestCase):
|
||||
def test_bf16_creation_numpy(self):
|
||||
data = [-1, 1, 2]
|
||||
t = Tensor(data, dtype=dtypes.bfloat16)
|
||||
assert t.dtype == dtypes.bfloat16
|
||||
tnp = t.numpy()
|
||||
assert tnp.dtype == np.float32
|
||||
np.testing.assert_allclose(tnp, np.array(data))
|
||||
|
||||
def test_bf16_ones(self):
|
||||
t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
|
||||
assert t.dtype == dtypes.bfloat16
|
||||
np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
|
||||
|
||||
def test_bf16_eye(self):
|
||||
t = Tensor.eye(3, dtype=dtypes.bfloat16)
|
||||
assert t.dtype == dtypes.bfloat16
|
||||
np.testing.assert_allclose(t.numpy(), np.eye(3))
|
||||
|
||||
class TestBFloat16DType(unittest.TestCase):
|
||||
def test_bf16_to_float(self):
|
||||
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
|
||||
|
|
@ -419,48 +369,6 @@ class TestEmulatedFp8e5m2(TestFp8e5m2):
|
|||
@classmethod
|
||||
def tearDownClass(cls): cls.stack.close()
|
||||
|
||||
class TestPtrDType(unittest.TestCase):
|
||||
def test_vec_double(self):
|
||||
dt1 = dtypes.float.vec(4).ptr().vec(4)
|
||||
dt2 = dtypes.float.vec(4).ptr().vec(4)
|
||||
self.assertEqual(dt1, dt2)
|
||||
self.assertEqual(str(dt1), str(dt2))
|
||||
|
||||
def test_scalar(self):
|
||||
dt = dtypes.float.vec(4).ptr().scalar()
|
||||
self.assertEqual(dt.base, dtypes.float.vec(4))
|
||||
|
||||
dt = dtypes.float.vec(4).ptr().vec(4).scalar()
|
||||
self.assertEqual(dt.base, dtypes.float.vec(4))
|
||||
|
||||
dt = dtypes.float.vec(4).scalar()
|
||||
self.assertEqual(dt, dtypes.float)
|
||||
|
||||
def test_serialize(self):
|
||||
dt = dtypes.float.vec(4).ptr().vec(4)
|
||||
self.assertEqual(dt, eval(str(dt)))
|
||||
|
||||
def test_vec_ptr_sz(self):
|
||||
dt = dtypes.float.ptr(1024).vec(4)
|
||||
self.assertEqual(dt, eval(str(dt)))
|
||||
self.assertEqual(str(dt), "dtypes.float.ptr(1024).vec(4)")
|
||||
|
||||
def test_vcount(self):
|
||||
dt = dtypes.float.ptr().vec(4)
|
||||
self.assertEqual(dt.vcount, 4)
|
||||
self.assertEqual(dt.v, 4)
|
||||
self.assertEqual(dt.count, 1)
|
||||
|
||||
dt = dtypes.float.vec(4).ptr()
|
||||
self.assertEqual(dt.vcount, 1)
|
||||
self.assertEqual(dt.v, 1)
|
||||
self.assertEqual(dt.count, 4)
|
||||
|
||||
dt = dtypes.float.vec(4).ptr().vec(4)
|
||||
self.assertEqual(dt.vcount, 4)
|
||||
self.assertEqual(dt.v, 4)
|
||||
self.assertEqual(dt.count, 4)
|
||||
|
||||
class TestImplicitFunctionTypeChange(unittest.TestCase):
|
||||
def test_functions(self):
|
||||
result = []
|
||||
|
|
|
|||
|
|
@ -139,12 +139,12 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@unittest.skipUnless(dtypes.bfloat16 in supported_dtypes, f"no bfloat16 on {Device.DEFAULT}")
|
||||
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
|
||||
def test_bfloat16(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(b, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@given(ht.bfloat16, ht.bfloat16, strat.sampled_from(binary_operations))
|
||||
@Context(EMULATED_DTYPES="bfloat16")
|
||||
def test_emulated_bfloat16(self, a, b, op):
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
universal_test(from_storage_scalar(a, dtypes.bfloat16), from_storage_scalar(b, dtypes.bfloat16), dtypes.bfloat16, op)
|
||||
|
||||
@unittest.skipUnless(dtypes.fp8e4m3 in supported_dtypes, f"no fp8e4m3 on {Device.DEFAULT}")
|
||||
@given(ht.fp8e4m3, ht.fp8e4m3, strat.sampled_from(binary_operations))
|
||||
|
|
|
|||
|
|
@ -10,6 +10,48 @@ class TestImageDType(unittest.TestCase):
|
|||
assert dtypes.imagef((10,10)).base.vec(4) == dtypes.float32.vec(4)
|
||||
assert dtypes.imageh((10,10)).base.vec(4) == dtypes.float32.vec(4)
|
||||
|
||||
class TestPtrDType(unittest.TestCase):
|
||||
def test_vec_double(self):
|
||||
dt1 = dtypes.float.vec(4).ptr().vec(4)
|
||||
dt2 = dtypes.float.vec(4).ptr().vec(4)
|
||||
self.assertEqual(dt1, dt2)
|
||||
self.assertEqual(str(dt1), str(dt2))
|
||||
|
||||
def test_scalar(self):
|
||||
dt = dtypes.float.vec(4).ptr().scalar()
|
||||
self.assertEqual(dt.base, dtypes.float.vec(4))
|
||||
|
||||
dt = dtypes.float.vec(4).ptr().vec(4).scalar()
|
||||
self.assertEqual(dt.base, dtypes.float.vec(4))
|
||||
|
||||
dt = dtypes.float.vec(4).scalar()
|
||||
self.assertEqual(dt, dtypes.float)
|
||||
|
||||
def test_serialize(self):
|
||||
dt = dtypes.float.vec(4).ptr().vec(4)
|
||||
self.assertEqual(dt, eval(str(dt)))
|
||||
|
||||
def test_vec_ptr_sz(self):
|
||||
dt = dtypes.float.ptr(1024).vec(4)
|
||||
self.assertEqual(dt, eval(str(dt)))
|
||||
self.assertEqual(str(dt), "dtypes.float.ptr(1024).vec(4)")
|
||||
|
||||
def test_vcount(self):
|
||||
dt = dtypes.float.ptr().vec(4)
|
||||
self.assertEqual(dt.vcount, 4)
|
||||
self.assertEqual(dt.v, 4)
|
||||
self.assertEqual(dt.count, 1)
|
||||
|
||||
dt = dtypes.float.vec(4).ptr()
|
||||
self.assertEqual(dt.vcount, 1)
|
||||
self.assertEqual(dt.v, 1)
|
||||
self.assertEqual(dt.count, 4)
|
||||
|
||||
dt = dtypes.float.vec(4).ptr().vec(4)
|
||||
self.assertEqual(dt.vcount, 4)
|
||||
self.assertEqual(dt.v, 4)
|
||||
self.assertEqual(dt.count, 4)
|
||||
|
||||
class TestEqStrDType(unittest.TestCase):
|
||||
def test_image_ne(self):
|
||||
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
|
||||
|
|
|
|||
|
|
@ -178,6 +178,12 @@ class TestHelpers(unittest.TestCase):
|
|||
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item())
|
||||
|
||||
def test_finfo(self):
|
||||
for dt in [dtypes.float16, dtypes.float32, dtypes.float64]:
|
||||
info = np.finfo(_to_np_dtype(dt))
|
||||
self.assertEqual(info.bits, dt.bitsize)
|
||||
self.assertEqual((info.nexp, info.nmant), dtypes.finfo(dt))
|
||||
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_self_promo_to_self(self, dtype):
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ class TestTypeSpec(unittest.TestCase):
|
|||
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
@unittest.skipUnless(dtypes.int8 in supported_dtypes, f"no int8 on {Device.DEFAULT}")
|
||||
def test_dtype_str_arg(self):
|
||||
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
|
||||
tested = 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue