Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
66634c643e test unit should pass on NULL device 2026-02-03 00:11:24 +08:00
24 changed files with 527 additions and 491 deletions

425
test/test_dtype_spec.py Normal file
View file

@ -0,0 +1,425 @@
import unittest, math, operator, subprocess
from tinygrad.tensor import Tensor, dtypes, Device
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_float
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG
from test.helpers import slow
from hypothesis import given, settings, strategies as strat
import numpy as np
import torch
settings.register_profile("my_profile", max_examples=50, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
core_dtypes = list(DTYPES_DICT.values())
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
def test_set_dtype_default(self):
for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
dtypes.default_int = default_int
assert dtypes.default_int == default_int
for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
dtypes.default_float = default_float
assert dtypes.default_float == default_float
@unittest.skip("this test is slow and spawning whole pythons")
def test_env_set_default_float(self):
# check default
subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
shell=True, check=True)
# check change
subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
shell=True, check=True)
# check invalid
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
def test_dtype_str_arg(self):
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
tested = 0
for dtype_str, dtype in [
("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
if dtype.itemsize == 4:
np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
tested += 1
assert tested == 3
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_creation(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor(True), dtypes.bool, True)
_assert_eq(Tensor(None), dtypes.default_float, [])
_assert_eq(Tensor(2), dtypes.default_int, 2)
_assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
_assert_eq(Tensor([]), dtypes.default_float, [])
_assert_eq(Tensor([1]), dtypes.default_int, [1])
_assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_full(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_reduce_0d_default(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3)))
# TODO: what should this one be?
# _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
_assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_arange(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
if is_dtype_supported(dtypes.int16):
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
# stop-start and step have different signs
_assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))
_assert_eq(Tensor.arange(5.0, 3.0), dtypes.default_float, np.arange(5.0, 3.0))
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
def test_bool_ops(self, dtype, op):
assert op(Tensor.ones(4, 4, dtype=dtype), Tensor.ones(4, 4, dtype=dtype)).dtype == dtypes.bool
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_functions_return_index(self, dtype, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor.ones(60000, 1, 28, 28, dtype=data_dtype)
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
assert X_data[indices].dtype == X_data.dtype
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
assert X_data.gather(0, indices).dtype == X_data.dtype
assert X_data.gather(1, indices).dtype == X_data.dtype
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_attention_returns_same_dtype(self, data_dtype, default_float):
dtypes.default_float = default_float
query = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
key = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
value = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
mask = (Tensor.rand(32, 8, 128, 128) < 0.5)
assert query.scaled_dot_product_attention(key, value, is_causal=True).dtype == data_dtype
assert query.scaled_dot_product_attention(key, value, is_causal=True, dropout_p=0.3).dtype == data_dtype
assert query.scaled_dot_product_attention(key, value, is_causal=False).dtype == data_dtype
assert query.scaled_dot_product_attention(key, value, attn_mask=mask).dtype == data_dtype
class TestAutoCastType(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_float(self, input_dtype, default_float):
dtypes.default_float = default_float
self.assertEqual(least_upper_float(input_dtype), input_dtype)
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_int(self, input_dtype, default_float):
dtypes.default_float = default_float
self.assertEqual(least_upper_float(input_dtype), default_float)
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
def test_int_to_float_unary_func(self, dtype):
for func in [
lambda t: t.exp(),
lambda t: t.exp2(),
lambda t: t.log(),
lambda t: t.log2(),
lambda t: t.sqrt(),
lambda t: t.rsqrt(),
lambda t: t.sin(),
lambda t: t.cos(),
lambda t: t.tan(),
lambda t: t.sigmoid(),
]:
a = [2, 3, 4]
# float16 can have larger precision errors
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
@given(strat.sampled_from(core_dtypes))
def test_broadcast_scalar(self, dt):
assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
@given(strat.sampled_from(dtype_floats))
def test_int_div_int(self, default_float):
dtypes.default_float = default_float
self.assertEqual(Tensor([1]).div(Tensor([2])).dtype, default_float)
def test_sum(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64
assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
def test_sum_dtype_arg(self):
t = Tensor([40000, 40000], dtype=dtypes.float16)
# default float16 sum returns in float16, overflowed in this case
assert t.sum().dtype == dtypes.float16
assert math.isinf(t.sum().numpy().item())
# specifiying dtype and it's not downcasted
assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)
def test_prod_dtype_arg(self):
t = Tensor([100, 200], dtype=dtypes.int32)
assert t.prod().dtype == dtypes.int32
np.testing.assert_allclose(t.prod().numpy(), 20000)
assert t.prod(dtype=dtypes.float32).dtype == dtypes.float32
np.testing.assert_allclose(t.prod(dtype=dtypes.float32).numpy(), 20000)
def test_mean(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int16)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int64)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint8)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
def test_cumsum(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64
assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum(0).dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum(0).dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_matmul(self, dt1, dt2, acc_dt):
t1 = Tensor([0, 1], dtype=dt1)
t2 = Tensor([0, 1], dtype=dt2)
from tinygrad.dtype import least_upper_dtype
self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
# if dtype is specified, return in dtype
self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_linear(self, dt1, dt2, dt3, acc_dt):
x = Tensor([0, 1], dtype=dt1)
w = Tensor([0, 1], dtype=dt2)
b = Tensor([0, 1], dtype=dt3)
from tinygrad.dtype import least_upper_dtype
self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
# if dtype is specified, return in dtype
self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
@staticmethod
def check_where_alternate_input_other(input_, other, data_type):
assert (Tensor([True, False]).where(input_, other)).dtype == data_type
assert (Tensor([True, False]).where(other, input_)).dtype == data_type
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_where_no_scalar(self, dt1, dt2):
from tinygrad.dtype import least_upper_dtype
self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
@given(strat.sampled_from(core_dtypes))
def test_where_one_scalar(self, dt):
t = Tensor(2, dtype=dt)
self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
self.check_where_alternate_input_other(t, True, dt)
def test_where_two_scalars(self):
self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
self.check_where_alternate_input_other(3, 2, dtypes.default_int)
self.check_where_alternate_input_other(3, True, dtypes.default_int)
self.check_where_alternate_input_other(False, True, dtypes.bool)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_maximum(self, dt1, dt2):
from tinygrad.dtype import least_upper_dtype
assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
@given(strat.sampled_from(core_dtypes))
def test_maximum_const(self, dt):
assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
def test_div(self):
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.int16) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.float32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float32
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float16
def test_div_const(self):
assert (Tensor([1, 2], dtype=dtypes.int32) / 2).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.int32) / 2.0).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
def test_gradient_dtype(self):
old_default_float = dtypes.default_float
for default_dtype in dtypes.floats:
if not is_dtype_supported(default_dtype): continue
dtypes.default_float = default_dtype
for dtype in dtypes.floats:
if not is_dtype_supported(dtype): continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
b = (a * 5).sum()
b.backward() # if there is dtype mismatch, lazy should assert
assert a.grad.dtype == a.dtype
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
dtypes.default_float = old_default_float
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
@slow
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_underflow(self):
N = 10000
x = 0.001
t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
@unittest.skip("this test only works with SPLIT_REDUCEOP=1")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_overflow(self):
N = 256
t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
np.testing.assert_allclose(t.mean().numpy(), 60000)
t.square().mean().backward()
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_softmax_dtype(self):
data = [1, 2, 3]
t = Tensor(data, dtype=dtypes.half)
tt = torch.tensor(data, dtype=torch.half)
out = t.softmax(0)
self.assertEqual(out.dtype, dtypes.half)
np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3)
out = t.softmax(0, dtype=dtypes.float)
self.assertEqual(out.dtype, dtypes.float)
np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
out = t.log_softmax(0)
self.assertEqual(out.dtype, dtypes.half)
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3)
out = t.log_softmax(0, dtype=dtypes.float)
self.assertEqual(out.dtype, dtypes.float)
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
if __name__ == '__main__':
unittest.main()

85
test/test_gradient.py Normal file
View file

@ -0,0 +1,85 @@
import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
class TestTensorGradient(unittest.TestCase):
def test_example(self):
x = Tensor.eye(3)
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)
self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])
def test_raises(self):
x = Tensor([1.0, 2.0, 3.0])
w = Tensor.randn((3,))
with self.assertRaises(RuntimeError): x.sum().gradient(w)
def test_with_custom_gradient(self):
x = Tensor([1.0, 2.0, 3.0])
z = (x * x).sum()
dx = z.gradient(x, gradient=Tensor([3.0]))[0]
self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])
def test_broadcast_gradient(self):
x = Tensor([[1.0], [2.0], [3.0]])
y = Tensor([[10.0, 20.0, 30.0, 40.0]])
z = (x + y).sum()
dx, dy = z.gradient(x, y)
self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]])
self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]])
def test_non_scalar_output(self):
x = Tensor([1.0, 2.0, 3.0])
z = x * x
with self.assertRaises(AssertionError): z.gradient(x)
dz = Tensor([1.0, 1.0, 1.0])
dx = z.gradient(x, gradient=dz)[0]
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
def test_cast_before_view(self):
x = Tensor([1.0, 1, 1, 1])
x_reshaped = x.reshape(2,2)
x_casted = x_reshaped.cast(dtypes.float16)
x_casted.mean().gradient(x_reshaped)
def test_non_float_tensor_raise(self):
x = Tensor([1, 2, 3])
with self.assertRaises(RuntimeError): x.sum().gradient(x)
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
def test_copy_to_device_gradient(self):
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
t.to("CPU:1").square().sum().backward()
self.assertEqual(t.grad.device, t.device)
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
def test_multiple_backward(self):
x = Tensor([3.], requires_grad=True)
(x*2)[0].backward()
np.testing.assert_allclose(x.grad.numpy(), [2.0])
old_grad = x.grad
(x*3)[0].backward()
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0])
self.assertIs(x.grad, old_grad)
(x*x)[0].backward()
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0+2*3.0])
self.assertIs(x.grad, old_grad)
class TestViewGradient(unittest.TestCase):
def test_expand(self):
# this test shows that if Tensors collapse to the views and create a disconnected graph
# there's no way to recover the proper gradient
x = Tensor.randn(5,2)
a = Tensor([3.], requires_grad=True)
aex = a.expand(10)
(aex.reshape(5,2) * x).sum().backward()
np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())
# NOTE: aex.grad is *not* a.grad.expand(10)!
with self.assertRaises(AssertionError):
np.testing.assert_allclose(aex.grad.numpy(), a.grad.expand(10).numpy())
if __name__ == '__main__':
unittest.main()

11
test/test_helpers.py Normal file
View file

@ -0,0 +1,11 @@
import unittest
import numpy as np
from tinygrad.helpers import polyN
class TestPolyN(unittest.TestCase):
def test_tensor(self):
from tinygrad.tensor import Tensor
np.testing.assert_allclose(polyN(Tensor([1.0, 2.0, 3.0, 4.0]), [1.0, -2.0, 1.0]).numpy(), [0.0, 1.0, 4.0, 9.0])
if __name__ == '__main__':
unittest.main()

View file

@ -1,9 +1,8 @@
import unittest, math, operator, subprocess, struct
from tinygrad.tensor import Tensor, dtypes, Device
from tinygrad.dtype import DType, DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
import unittest, math, struct
from tinygrad.tensor import dtypes
from tinygrad.dtype import DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG
from test.helpers import slow
from tinygrad.helpers import getenv
from hypothesis import given, settings, strategies as strat
import numpy as np
import torch
@ -12,22 +11,10 @@ settings.register_profile("my_profile", max_examples=50, deadline=None, derandom
settings.load_profile("my_profile")
core_dtypes = list(DTYPES_DICT.values())
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
FP8E4M3_MAX = 448.0
FP8E5M2_MAX = 57344.0
def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0]
def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0]
@ -202,163 +189,6 @@ class TestHelpers(unittest.TestCase):
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item())
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
def test_set_dtype_default(self):
for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
dtypes.default_int = default_int
assert dtypes.default_int == default_int
for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
dtypes.default_float = default_float
assert dtypes.default_float == default_float
@unittest.skip("this test is slow and spawning whole pythons")
def test_env_set_default_float(self):
# check default
subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
shell=True, check=True)
# check change
subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
shell=True, check=True)
# check invalid
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
def test_dtype_str_arg(self):
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
tested = 0
for dtype_str, dtype in [
("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
if dtype.itemsize == 4:
np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
tested += 1
assert tested == 3
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_creation(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor(True), dtypes.bool, True)
_assert_eq(Tensor(None), dtypes.default_float, [])
_assert_eq(Tensor(2), dtypes.default_int, 2)
_assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
_assert_eq(Tensor([]), dtypes.default_float, [])
_assert_eq(Tensor([1]), dtypes.default_int, [1])
_assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_full(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_reduce_0d_default(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3)))
# TODO: what should this one be?
# _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
_assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_arange(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
if is_dtype_supported(dtypes.int16):
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
# stop-start and step have different signs
_assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))
_assert_eq(Tensor.arange(5.0, 3.0), dtypes.default_float, np.arange(5.0, 3.0))
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
def test_bool_ops(self, dtype, op):
assert op(Tensor.ones(4, 4, dtype=dtype), Tensor.ones(4, 4, dtype=dtype)).dtype == dtypes.bool
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_functions_return_index(self, dtype, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor.ones(60000, 1, 28, 28, dtype=data_dtype)
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
assert X_data[indices].dtype == X_data.dtype
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
assert X_data.gather(0, indices).dtype == X_data.dtype
assert X_data.gather(1, indices).dtype == X_data.dtype
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_attention_returns_same_dtype(self, data_dtype, default_float):
dtypes.default_float = default_float
query = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
key = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
value = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
mask = (Tensor.rand(32, 8, 128, 128) < 0.5)
assert query.scaled_dot_product_attention(key, value, is_causal=True).dtype == data_dtype
assert query.scaled_dot_product_attention(key, value, is_causal=True, dropout_p=0.3).dtype == data_dtype
assert query.scaled_dot_product_attention(key, value, is_causal=False).dtype == data_dtype
assert query.scaled_dot_product_attention(key, value, attn_mask=mask).dtype == data_dtype
class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))
def test_self_promo_to_self(self, dtype):
@ -398,237 +228,5 @@ class TestTypePromotion(unittest.TestCase):
assert least_upper_dtype(dtypes.fp8e5m2, dtypes.int64) == dtypes.fp8e5m2
assert least_upper_dtype(dtypes.fp8e5m2, dtypes.uint64) == dtypes.fp8e5m2
class TestAutoCastType(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_float(self, input_dtype, default_float):
dtypes.default_float = default_float
self.assertEqual(least_upper_float(input_dtype), input_dtype)
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_int(self, input_dtype, default_float):
dtypes.default_float = default_float
self.assertEqual(least_upper_float(input_dtype), default_float)
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
def test_int_to_float_unary_func(self, dtype):
for func in [
lambda t: t.exp(),
lambda t: t.exp2(),
lambda t: t.log(),
lambda t: t.log2(),
lambda t: t.sqrt(),
lambda t: t.rsqrt(),
lambda t: t.sin(),
lambda t: t.cos(),
lambda t: t.tan(),
lambda t: t.sigmoid(),
]:
a = [2, 3, 4]
# float16 can have larger precision errors
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
@given(strat.sampled_from(core_dtypes))
def test_broadcast_scalar(self, dt):
assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
@given(strat.sampled_from(dtype_floats))
def test_int_div_int(self, default_float):
dtypes.default_float = default_float
self.assertEqual(Tensor([1]).div(Tensor([2])).dtype, default_float)
def test_sum(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64
assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
def test_sum_dtype_arg(self):
t = Tensor([40000, 40000], dtype=dtypes.float16)
# default float16 sum returns in float16, overflowed in this case
assert t.sum().dtype == dtypes.float16
assert math.isinf(t.sum().numpy().item())
# specifiying dtype and it's not downcasted
assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)
def test_prod_dtype_arg(self):
t = Tensor([100, 200], dtype=dtypes.int32)
assert t.prod().dtype == dtypes.int32
np.testing.assert_allclose(t.prod().numpy(), 20000)
assert t.prod(dtype=dtypes.float32).dtype == dtypes.float32
np.testing.assert_allclose(t.prod(dtype=dtypes.float32).numpy(), 20000)
def test_mean(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int16)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int64)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint8)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
def test_cumsum(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64
assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum(0).dtype == dtypes.fp8e4m3
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum(0).dtype == dtypes.fp8e5m2
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_matmul(self, dt1, dt2, acc_dt):
t1 = Tensor([0, 1], dtype=dt1)
t2 = Tensor([0, 1], dtype=dt2)
self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
# if dtype is specified, return in dtype
self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_linear(self, dt1, dt2, dt3, acc_dt):
x = Tensor([0, 1], dtype=dt1)
w = Tensor([0, 1], dtype=dt2)
b = Tensor([0, 1], dtype=dt3)
self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
# if dtype is specified, return in dtype
self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
@staticmethod
def check_where_alternate_input_other(input_, other, data_type):
assert (Tensor([True, False]).where(input_, other)).dtype == data_type
assert (Tensor([True, False]).where(other, input_)).dtype == data_type
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_where_no_scalar(self, dt1, dt2):
self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
@given(strat.sampled_from(core_dtypes))
def test_where_one_scalar(self, dt):
t = Tensor(2, dtype=dt)
self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
self.check_where_alternate_input_other(t, True, dt)
def test_where_two_scalars(self):
self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
self.check_where_alternate_input_other(3, 2, dtypes.default_int)
self.check_where_alternate_input_other(3, True, dtypes.default_int)
self.check_where_alternate_input_other(False, True, dtypes.bool)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_maximum(self, dt1, dt2):
assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
@given(strat.sampled_from(core_dtypes))
def test_maximum_const(self, dt):
assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
def test_div(self):
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.int16) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.float32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float32
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float16
def test_div_const(self):
assert (Tensor([1, 2], dtype=dtypes.int32) / 2).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.int32) / 2.0).dtype == dtypes.default_float
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
def test_gradient_dtype(self):
old_default_float = dtypes.default_float
for default_dtype in dtypes.floats:
if not is_dtype_supported(default_dtype): continue
dtypes.default_float = default_dtype
for dtype in dtypes.floats:
if not is_dtype_supported(dtype): continue
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
b = (a * 5).sum()
b.backward() # if there is dtype mismatch, lazy should assert
assert a.grad.dtype == a.dtype
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
dtypes.default_float = old_default_float
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
@slow
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_underflow(self):
N = 10000
x = 0.001
t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
@unittest.skip("this test only works with SPLIT_REDUCEOP=1")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_overflow(self):
N = 256
t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
np.testing.assert_allclose(t.mean().numpy(), 60000)
t.square().mean().backward()
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_softmax_dtype(self):
data = [1, 2, 3]
t = Tensor(data, dtype=dtypes.half)
tt = torch.tensor(data, dtype=torch.half)
out = t.softmax(0)
self.assertEqual(out.dtype, dtypes.half)
np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3)
out = t.softmax(0, dtype=dtypes.float)
self.assertEqual(out.dtype, dtypes.float)
np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
out = t.log_softmax(0)
self.assertEqual(out.dtype, dtypes.half)
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3)
out = t.log_softmax(0, dtype=dtypes.float)
self.assertEqual(out.dtype, dtypes.float)
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
if __name__ == '__main__':
unittest.main()

View file

@ -1,7 +1,6 @@
from typing import Callable
import unittest, math
import torch
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp
@ -63,71 +62,6 @@ class TestGradient(unittest.TestCase):
def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y)
def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: torch.where(x<y,x,y))
class TestTensorGradient(unittest.TestCase):
def test_example(self):
x = Tensor.eye(3)
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)
self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])
def test_raises(self):
x = Tensor([1.0, 2.0, 3.0])
w = Tensor.randn((3,))
with self.assertRaises(RuntimeError): x.sum().gradient(w)
def test_with_custom_gradient(self):
x = Tensor([1.0, 2.0, 3.0])
z = (x * x).sum()
dx = z.gradient(x, gradient=Tensor([3.0]))[0]
self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])
def test_broadcast_gradient(self):
x = Tensor([[1.0], [2.0], [3.0]])
y = Tensor([[10.0, 20.0, 30.0, 40.0]])
z = (x + y).sum()
dx, dy = z.gradient(x, y)
self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]])
self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]])
def test_non_scalar_output(self):
x = Tensor([1.0, 2.0, 3.0])
z = x * x
with self.assertRaises(AssertionError): z.gradient(x)
dz = Tensor([1.0, 1.0, 1.0])
dx = z.gradient(x, gradient=dz)[0]
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
def test_cast_before_view(self):
x = Tensor([1.0, 1, 1, 1])
x_reshaped = x.reshape(2,2)
x_casted = x_reshaped.cast(dtypes.float16)
x_casted.mean().gradient(x_reshaped)
def test_non_float_tensor_raise(self):
x = Tensor([1, 2, 3])
with self.assertRaises(RuntimeError): x.sum().gradient(x)
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
def test_copy_to_device_gradient(self):
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
t.to("CPU:1").square().sum().backward()
self.assertEqual(t.grad.device, t.device)
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
def test_multiple_backward(self):
x = Tensor([3.], requires_grad=True)
(x*2)[0].backward()
np.testing.assert_allclose(x.grad.numpy(), [2.0])
old_grad = x.grad
(x*3)[0].backward()
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0])
self.assertIs(x.grad, old_grad)
(x*x)[0].backward()
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0+2*3.0])
self.assertIs(x.grad, old_grad)
class TestRealizeMeansRealize(unittest.TestCase):
def test_randn_realizes(self):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
@ -148,18 +82,5 @@ class TestRealizeMeansRealize(unittest.TestCase):
y = x * 2
y.sum().gradient(x)[0].realize()
class TestViewGradient(unittest.TestCase):
def test_expand(self):
# this test shows that if Tensors collapse to the views and create a disconnected graph
# there's no way to recover the proper gradient
x = Tensor.randn(5,2)
a = Tensor([3.], requires_grad=True)
aex = a.expand(10)
(aex.reshape(5,2) * x).sum().backward()
np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())
# NOTE: aex.grad is *not* a.grad.expand(10)!
with self.assertRaises(AssertionError):
np.testing.assert_allclose(aex.grad.numpy(), a.grad.expand(10).numpy())
if __name__ == '__main__':
unittest.main()

View file

@ -356,10 +356,6 @@ class TestPolyN(unittest.TestCase):
np.testing.assert_allclose(polyN(3.0, [1.0, -2.0, 1.0]), 4.0)
np.testing.assert_allclose(polyN(4.0, [1.0, -2.0, 1.0]), 9.0)
def test_tensor(self):
from tinygrad.tensor import Tensor
np.testing.assert_allclose(polyN(Tensor([1.0, 2.0, 3.0, 4.0]), [1.0, -2.0, 1.0]).numpy(), [0.0, 1.0, 4.0, 9.0])
def test_uop(self):
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp