Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
66634c643e test unit should pass on NULL device 2026-02-03 00:11:24 +08:00
24 changed files with 527 additions and 491 deletions

425
test/test_dtype_spec.py Normal file
View file

@ -0,0 +1,425 @@
import unittest, math, operator, subprocess
from tinygrad.tensor import Tensor, dtypes, Device
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_float
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG
from test.helpers import slow
from hypothesis import given, settings, strategies as strat
import numpy as np
import torch
# hypothesis profile used by the @given tests in this file: at most 50 examples per test, no per-example deadline,
# and derandomized example generation when getenv("DERANDOMIZE_CI", False) is truthy
settings.register_profile("my_profile", max_examples=50, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
# every dtype listed in DTYPES_DICT
core_dtypes = list(DTYPES_DICT.values())
# the integer / float subsets of core_dtypes that is_dtype_supported accepts; used as sampling pools below
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
  """Assert that `tensor` has dtype `target_dtype` and that its values are close to `target`.

  The relative tolerance is looser for float16, bfloat16 and the two fp8 dtypes; every other
  dtype uses `tol_target_dtype`. Any failure is re-raised as an AssertionError whose message
  shows the tensor values, its dtype, the target and the target dtype.
  """
  if DEBUG >= 2:
    print(tensor.numpy())
  try:
    assert tensor.dtype == target_dtype
    # per-dtype relative tolerance, falling back to the caller supplied one
    reduced_precision_rtol = {dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}
    rtol = reduced_precision_rtol.get(target_dtype, tol_target_dtype)
    np.testing.assert_allclose(tensor.numpy(), target, rtol=rtol)
  except AssertionError as e:
    raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
class TestTypeSpec(unittest.TestCase):
  """Dtype checks for tensor creation, indexing-style ops and attention under varying default dtypes."""

  def setUp(self):
    # snapshot the global defaults; most tests in this class overwrite them
    self.old_default_int = dtypes.default_int
    self.old_default_float = dtypes.default_float

  def tearDown(self):
    # put the global defaults back to what setUp saved
    dtypes.default_int = self.old_default_int
    dtypes.default_float = self.old_default_float

  def test_set_dtype_default(self):
    """Assigning dtypes.default_int / dtypes.default_float is visible when reading them back."""
    int_candidates = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]
    for candidate in int_candidates:
      dtypes.default_int = candidate
      assert dtypes.default_int == candidate

    float_candidates = [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]
    for candidate in float_candidates:
      dtypes.default_float = candidate
      assert dtypes.default_float == candidate

  @unittest.skip("this test is slow and spawning whole pythons")
  def test_env_set_default_float(self):
    """DEFAULT_FLOAT in the environment selects the default float; invalid values make the import fail."""
    def run_in_shell(command):
      # check=True turns a non-zero exit status into subprocess.CalledProcessError
      subprocess.run([command], shell=True, check=True)

    # check default
    run_in_shell('python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"')
    # check change
    run_in_shell('DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"')
    # check invalid
    invalid_commands = (
      'DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"',
      'DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"',
    )
    for command in invalid_commands:
      with self.assertRaises(subprocess.CalledProcessError):
        run_in_shell(command)

  @unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
  def test_dtype_str_arg(self):
    """A dtype given by name behaves like the dtype object in the constructor, cast, bitcast and sum."""
    n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
    named_dtypes = [
      ("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]
    tested = 0
    for dtype_str, dtype in named_dtypes:
      np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
      np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
      # bitcast is only exercised for 4-byte dtypes (n is float32)
      if dtype.itemsize != 4:
        continue
      np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
      tested += 1
    assert tested == 3

    for bad_name in ("nonexistdtype", ""):
      with self.assertRaises(AttributeError):
        Tensor([1, 2, 3], dtype=bad_name)

    np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_creation(self, default_int, default_float):
    """Tensor(...) and Tensor.eye pick bool / default int / default float, or the explicit dtype."""
    dtypes.default_int, dtypes.default_float = default_int, default_float
    int_dt, float_dt = dtypes.default_int, dtypes.default_float

    _assert_eq(Tensor(True), dtypes.bool, True)
    _assert_eq(Tensor(None), float_dt, [])
    _assert_eq(Tensor(2), int_dt, 2)
    _assert_eq(Tensor(2.34), float_dt, 2.34)
    _assert_eq(Tensor([]), float_dt, [])
    _assert_eq(Tensor([1]), int_dt, [1])
    _assert_eq(Tensor([1.1]), float_dt, [1.1])

    _assert_eq(Tensor.eye(0), float_dt, np.eye(0))
    _assert_eq(Tensor.eye(3), float_dt, np.eye(3))
    for explicit in (dtypes.int64, dtypes.float16):
      if is_dtype_supported(explicit):
        _assert_eq(Tensor.eye(3, dtype=explicit), explicit, np.eye(3))

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_full(self, default_int, default_float):
    """zeros / ones / full use the default dtype matching the fill value, or the explicit dtype."""
    dtypes.default_int, dtypes.default_float = default_int, default_float
    int_dt, float_dt = dtypes.default_int, dtypes.default_float
    shape = (2, 3)
    explicit_dtypes = (dtypes.int64, dtypes.float16)

    # zeros first, then ones: default dtype, then each supported explicit dtype
    for make, np_make in ((Tensor.zeros, np.zeros), (Tensor.ones, np.ones)):
      _assert_eq(make(shape), float_dt, np_make(shape))
      for explicit in explicit_dtypes:
        if is_dtype_supported(explicit):
          _assert_eq(make(shape, dtype=explicit), explicit, np_make(shape))

    _assert_eq(Tensor.full(shape, 3.0), float_dt, np.full(shape, 3.0))
    _assert_eq(Tensor.full(shape, 3), int_dt, np.full(shape, 3))
    _assert_eq(Tensor.full(shape, True), dtypes.bool, np.full(shape, True))

    # an explicit dtype wins over the python type of the fill value
    for explicit in explicit_dtypes:
      if is_dtype_supported(explicit):
        for fill_value in (3, 3.0):
          _assert_eq(Tensor.full(shape, fill_value, dtype=explicit), explicit, np.full(shape, 3))

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_reduce_0d_default(self, default_int, default_float):
    """Summing over a zero-length axis yields zeros in the input dtype."""
    dtypes.default_int, dtypes.default_float = default_int, default_float
    empty_shape, zeros = (2, 3, 0), np.zeros((2, 3))
    _assert_eq(Tensor.ones(empty_shape).sum(2), dtypes.default_float, zeros)
    # TODO: what should this one be?
    # _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
    _assert_eq(Tensor.ones(empty_shape, dtype=dtypes.int32).sum(2), dtypes.int32, zeros)

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_arange(self, default_int, default_float):
    """arange is default int for int arguments, default float once any argument is a float."""
    dtypes.default_int, dtypes.default_float = default_int, default_float
    int_dt, float_dt = dtypes.default_int, dtypes.default_float

    _assert_eq(Tensor.arange(5), int_dt, np.arange(5))
    _assert_eq(Tensor.arange(120), int_dt, np.arange(120))
    _assert_eq(Tensor.arange(5.0), float_dt, np.arange(5))

    for explicit in (dtypes.int16, dtypes.int64, dtypes.float16):
      if is_dtype_supported(explicit):
        _assert_eq(Tensor.arange(5, dtype=explicit), explicit, np.arange(5))

    # WEBGPU gets a looser tolerance for the fractional step
    fractional_step_tol = 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7
    _assert_eq(Tensor.arange(3, 9, 0.7), float_dt, np.arange(3, 9, 0.7), fractional_step_tol)
    _assert_eq(Tensor.arange(3, 8.5, 3), float_dt, np.arange(3, 8.5, 3))

    # stop-start and step have different signs
    _assert_eq(Tensor.arange(3, 5, -2), int_dt, np.arange(3, 5, -2))
    _assert_eq(Tensor.arange(5.0, 3.0), float_dt, np.arange(5.0, 3.0))

  @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
  def test_bool_ops(self, dtype, op):
    """Comparison operators always produce a bool tensor."""
    lhs = Tensor.ones(4, 4, dtype=dtype)
    rhs = Tensor.ones(4, 4, dtype=dtype)
    assert op(lhs, rhs).dtype == dtypes.bool

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_functions_return_index(self, dtype, default_int, default_float):
    """argmax / argmin / multinomial return int32 regardless of input dtype and defaults."""
    dtypes.default_int, dtypes.default_float = default_int, default_float
    for index_fn in (Tensor.argmax, Tensor.argmin, Tensor.multinomial):
      assert index_fn(Tensor([0, 1], dtype=dtype)).dtype == dtypes.int32

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
  def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
    """Indexing with an integer tensor keeps the dtype of the indexed data."""
    X_data = Tensor.ones(60000, 1, 28, 28, dtype=data_dtype)
    indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
    picked = X_data[indices]
    assert picked.dtype == X_data.dtype

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
  def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
    """gather keeps the dtype of the data along either dimension."""
    X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
    indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
    for dim in (0, 1):
      assert X_data.gather(dim, indices).dtype == X_data.dtype

  @given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
  def test_attention_returns_same_dtype(self, data_dtype, default_float):
    """scaled_dot_product_attention returns the dtype of its inputs, independent of the default float."""
    dtypes.default_float = default_float
    qkv_shape = (32, 8, 128, 64)
    query = Tensor.rand(*qkv_shape, dtype=data_dtype)
    key = Tensor.rand(*qkv_shape, dtype=data_dtype)
    value = Tensor.rand(*qkv_shape, dtype=data_dtype)
    mask = (Tensor.rand(32, 8, 128, 128) < 0.5)
    attention_options = (
      {"is_causal": True},
      {"is_causal": True, "dropout_p": 0.3},
      {"is_causal": False},
      {"attn_mask": mask},
    )
    for options in attention_options:
      assert query.scaled_dot_product_attention(key, value, **options).dtype == data_dtype
class TestAutoCastType(unittest.TestCase):
  """Output dtype checks for unary funcs, scalar broadcasting, reductions, matmul/linear, where, maximum, div, softmax."""

  def setUp(self):
    # snapshot the global defaults; several tests overwrite dtypes.default_float
    self.old_default_int = dtypes.default_int
    self.old_default_float = dtypes.default_float

  def tearDown(self):
    # put the global defaults back to what setUp saved
    dtypes.default_int = self.old_default_int
    dtypes.default_float = self.old_default_float

  @given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
  def test_least_upper_float_input_is_float(self, input_dtype, default_float):
    """least_upper_float returns a float input dtype unchanged."""
    dtypes.default_float = default_float
    promoted = least_upper_float(input_dtype)
    self.assertEqual(promoted, input_dtype)

  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
  def test_least_upper_float_input_is_int(self, input_dtype, default_float):
    """least_upper_float maps an int input dtype to the current default float."""
    dtypes.default_float = default_float
    promoted = least_upper_float(input_dtype)
    self.assertEqual(promoted, default_float)

  @given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
  def test_int_to_float_unary_func(self, dtype):
    """Float-valued unary functions on int tensors match torch on the same values."""
    unary_names = ("exp", "exp2", "log", "log2", "sqrt", "rsqrt", "sin", "cos", "tan", "sigmoid")
    for unary_name in unary_names:
      a = [2, 3, 4]
      actual = getattr(Tensor(a, dtype=dtype), unary_name)().numpy()
      desired = getattr(torch.tensor(a), unary_name)()
      # float16 can have larger precision errors
      np.testing.assert_allclose(actual, desired, rtol=1e-3, atol=1e-3)

  @given(strat.sampled_from(core_dtypes))
  def test_broadcast_scalar(self, dt):
    """Adding a python scalar only changes the dtype when the scalar's kind outranks the tensor's."""
    is_float, is_int = dtypes.is_float(dt), dtypes.is_int(dt)
    assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if is_float else dtypes.default_float)
    assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if is_float or is_int else dtypes.default_int)
    assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt

  @given(strat.sampled_from(dtype_floats))
  def test_int_div_int(self, default_float):
    """Dividing two int tensors yields the default float."""
    dtypes.default_float = default_float
    quotient = Tensor([1]).div(Tensor([2]))
    self.assertEqual(quotient.dtype, default_float)

  def test_sum(self):
    """sum widens bool and small ints to 32 bits (keeping signedness) and keeps every other dtype."""
    expected_by_input = [
      (dtypes.bool, dtypes.int32),
      (dtypes.int8, dtypes.int32), (dtypes.int16, dtypes.int32), (dtypes.int32, dtypes.int32), (dtypes.int64, dtypes.int64),
      (dtypes.uint8, dtypes.uint32), (dtypes.uint16, dtypes.uint32), (dtypes.uint32, dtypes.uint32), (dtypes.uint64, dtypes.uint64),
      (dtypes.fp8e4m3, dtypes.fp8e4m3), (dtypes.fp8e5m2, dtypes.fp8e5m2),
      (dtypes.float16, dtypes.float16), (dtypes.bfloat16, dtypes.bfloat16),
      (dtypes.float32, dtypes.float32), (dtypes.float64, dtypes.float64),
    ]
    for input_dtype, output_dtype in expected_by_input:
      assert (Tensor([0, 1], dtype=input_dtype)).sum().dtype == output_dtype

  @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
  def test_sum_dtype_arg(self):
    """sum(dtype=...) returns the requested dtype instead of the input dtype."""
    t = Tensor([40000, 40000], dtype=dtypes.float16)
    # default float16 sum returns in float16, overflowed in this case
    assert t.sum().dtype == dtypes.float16
    assert math.isinf(t.sum().numpy().item())
    # specifying dtype and it's not downcasted
    wide_sum = t.sum(dtype=dtypes.float32)
    assert wide_sum.dtype == dtypes.float32
    np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)

  def test_prod_dtype_arg(self):
    """prod keeps int32 by default and honours an explicit dtype."""
    t = Tensor([100, 200], dtype=dtypes.int32)
    assert t.prod().dtype == dtypes.int32
    np.testing.assert_allclose(t.prod().numpy(), 20000)
    assert t.prod(dtype=dtypes.float32).dtype == dtypes.float32
    np.testing.assert_allclose(t.prod(dtype=dtypes.float32).numpy(), 20000)

  def test_mean(self):
    """mean of bool/int tensors is float32; mean of float tensors keeps the input dtype."""
    non_float_inputs = (dtypes.bool, dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
                        dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
    for input_dtype in non_float_inputs:
      assert (Tensor([0, 1], dtype=input_dtype)).mean().dtype == dtypes.float32

    float_inputs = (dtypes.fp8e4m3, dtypes.fp8e5m2, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
    for input_dtype in float_inputs:
      assert (Tensor([0, 1], dtype=input_dtype)).mean().dtype == input_dtype

  def test_cumsum(self):
    """cumsum follows the same input -> output dtype table as sum."""
    expected_by_input = [
      (dtypes.bool, dtypes.int32),
      (dtypes.int8, dtypes.int32), (dtypes.int16, dtypes.int32), (dtypes.int32, dtypes.int32), (dtypes.int64, dtypes.int64),
      (dtypes.uint8, dtypes.uint32), (dtypes.uint16, dtypes.uint32), (dtypes.uint32, dtypes.uint32), (dtypes.uint64, dtypes.uint64),
      (dtypes.fp8e4m3, dtypes.fp8e4m3), (dtypes.fp8e5m2, dtypes.fp8e5m2),
      (dtypes.float16, dtypes.float16), (dtypes.bfloat16, dtypes.bfloat16),
      (dtypes.float32, dtypes.float32), (dtypes.float64, dtypes.float64),
    ]
    for input_dtype, output_dtype in expected_by_input:
      assert (Tensor([0, 1], dtype=input_dtype)).cumsum(0).dtype == output_dtype

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_matmul(self, dt1, dt2, acc_dt):
    """matmul returns least_upper_dtype of its operands unless dtype= is given."""
    t1 = Tensor([0, 1], dtype=dt1)
    t2 = Tensor([0, 1], dtype=dt2)
    from tinygrad.dtype import least_upper_dtype
    promoted = least_upper_dtype(t1.dtype, t2.dtype)
    self.assertEqual(t1.matmul(t2).dtype, promoted)
    # if dtype is specified, return in dtype
    self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_linear(self, dt1, dt2, dt3, acc_dt):
    """linear promotes input, weight and (optional) bias unless dtype= is given."""
    x = Tensor([0, 1], dtype=dt1)
    w = Tensor([0, 1], dtype=dt2)
    b = Tensor([0, 1], dtype=dt3)
    from tinygrad.dtype import least_upper_dtype
    self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
    self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
    # if dtype is specified, return in dtype
    for result in (x.linear(w, dtype=acc_dt), x.linear(w, b, dtype=acc_dt)):
      self.assertEqual(result.dtype, acc_dt)

  @staticmethod
  def check_where_alternate_input_other(input_, other, data_type):
    """Assert `where` yields `data_type` with `input_` and `other` in either branch order."""
    for on_true, on_false in ((input_, other), (other, input_)):
      assert (Tensor([True, False]).where(on_true, on_false)).dtype == data_type

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_where_no_scalar(self, dt1, dt2):
    """where on two tensors returns least_upper_dtype of the branches."""
    from tinygrad.dtype import least_upper_dtype
    self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))

  @given(strat.sampled_from(core_dtypes))
  def test_where_one_scalar(self, dt):
    """where on a tensor and a python scalar follows the scalar broadcasting rule."""
    t = Tensor(2, dtype=dt)
    is_float, is_int = dtypes.is_float(dt), dtypes.is_int(dt)
    self.check_where_alternate_input_other(t, 3.2, (dt if is_float else dtypes.default_float))
    self.check_where_alternate_input_other(t, 3, (dt if is_float or is_int else dtypes.default_int))
    self.check_where_alternate_input_other(t, True, dt)

  def test_where_two_scalars(self):
    """where on two python scalars picks default float, default int or bool."""
    scalar_cases = [
      (3.1, 3.2, dtypes.default_float),
      (3.1, 3, dtypes.default_float),
      (3.1, True, dtypes.default_float),
      (3, 2, dtypes.default_int),
      (3, True, dtypes.default_int),
      (False, True, dtypes.bool),
    ]
    for first, second, want in scalar_cases:
      self.check_where_alternate_input_other(first, second, want)

  @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
  def test_maximum(self, dt1, dt2):
    """maximum of two tensors returns least_upper_dtype of the operands."""
    from tinygrad.dtype import least_upper_dtype
    lhs = Tensor([0, 1, 2], dtype=dt1)
    rhs = Tensor([2, 0, 5], dtype=dt2)
    assert lhs.maximum(rhs).dtype == least_upper_dtype(dt1, dt2)

  @given(strat.sampled_from(core_dtypes))
  def test_maximum_const(self, dt):
    """maximum with a python scalar follows the scalar broadcasting rule."""
    is_float, is_int = dtypes.is_float(dt), dtypes.is_int(dt)
    assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if is_float else dtypes.default_float)
    assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if is_float or is_int else dtypes.default_int)
    assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt

  def test_div(self):
    """Tensor / Tensor: int / int gives the default float, otherwise the floats are promoted."""
    tensor_cases = [
      (dtypes.int32, dtypes.int32, dtypes.default_float),
      (dtypes.int16, dtypes.int32, dtypes.default_float),
      (dtypes.float32, dtypes.float16, dtypes.float32),
      (dtypes.int32, dtypes.float16, dtypes.float16),
    ]
    for numerator_dtype, denominator_dtype, want in tensor_cases:
      assert (Tensor([1, 2], dtype=numerator_dtype) / Tensor([2, 2], dtype=denominator_dtype)).dtype == want

  def test_div_const(self):
    """Tensor / python scalar: int tensors give the default float, float16 tensors stay float16."""
    scalar_cases = [
      (dtypes.int32, 2, dtypes.default_float),
      (dtypes.int32, 2.0, dtypes.default_float),
      (dtypes.float16, 2, dtypes.float16),
      (dtypes.float16, 2.0, dtypes.float16),
    ]
    for numerator_dtype, divisor, want in scalar_cases:
      assert (Tensor([1, 2], dtype=numerator_dtype) / divisor).dtype == want

  def test_gradient_dtype(self):
    """The gradient has the dtype of its tensor for every supported (default float, tensor float) pair."""
    old_default_float = dtypes.default_float

    supported_floats = [dt for dt in dtypes.floats if is_dtype_supported(dt)]
    for default_dtype in supported_floats:
      dtypes.default_float = default_dtype
      for dtype in supported_floats:
        # the loop variable names are part of the debug output via the f-string `=` specifier
        if DEBUG >= 2:
          print(f"testing {default_dtype=}, {dtype=}")
        a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
        b = (a * 5).sum()
        b.backward()  # if there is dtype mismatch, lazy should assert
        assert a.grad.dtype == a.dtype
        np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])

    dtypes.default_float = old_default_float

  @unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
  @slow
  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_mean_half_precision_underflow(self):
    """Row means of an N x N half tensor filled with 0.001 stay at 0.001 within rtol=1e-3."""
    N = 10000
    x = 0.001
    t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
    expected_rows = np.array([x] * N, dtype=np.float16)
    np.testing.assert_allclose(t.mean(axis=1).numpy(), expected_rows, rtol=1e-3)

  @unittest.skip("this test only works with SPLIT_REDUCEOP=1")
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_mean_half_precision_overflow(self):
    """Mean (and the gradient of the mean of squares) of a half tensor filled with 60000."""
    N = 256
    t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
    np.testing.assert_allclose(t.mean().numpy(), 60000)
    t.square().mean().backward()
    expected_grad = [60000 * 2 / (N*N)] * N*N
    np.testing.assert_allclose(t.grad.numpy().flatten(), expected_grad)

  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_softmax_dtype(self):
    """softmax / log_softmax keep half by default, return float when dtype= is given, and match torch."""
    data = [1, 2, 3]
    t = Tensor(data, dtype=dtypes.half)
    tt = torch.tensor(data, dtype=torch.half)

    for fn_name in ("softmax", "log_softmax"):
      # default: output stays in the input dtype
      out = getattr(t, fn_name)(0)
      self.assertEqual(out.dtype, dtypes.half)
      np.testing.assert_allclose(out.numpy(), getattr(tt, fn_name)(0).numpy(), rtol=1e-3)

      # explicit dtype: output is returned in that dtype
      out = getattr(t, fn_name)(0, dtype=dtypes.float)
      self.assertEqual(out.dtype, dtypes.float)
      np.testing.assert_allclose(out.numpy(), getattr(tt, fn_name)(0, dtype=torch.float).numpy(), rtol=1e-3)
# run this file's tests when it is executed as a script
if __name__ == '__main__':
  unittest.main()

85
test/test_gradient.py Normal file
View file

@ -0,0 +1,85 @@
import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
class TestTensorGradient(unittest.TestCase):
  """Tests for Tensor.gradient and for gradient accumulation through Tensor.backward."""

  def test_example(self):
    """Gradient of sum(y @ x) with respect to both matmul operands."""
    x = Tensor.eye(3)
    y = Tensor([[2.0,0,-2.0]])
    z = y.matmul(x).sum()
    dx, dy = z.gradient(x, y)
    expected_dx = [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]]
    expected_dy = [[1.0, 1.0, 1.0]]
    self.assertListEqual(dx.tolist(), expected_dx)
    self.assertListEqual(dy.tolist(), expected_dy)

  def test_raises(self):
    """Asking for the gradient of a tensor that did not take part in the computation raises RuntimeError."""
    x = Tensor([1.0, 2.0, 3.0])
    w = Tensor.randn((3,))
    loss = x.sum()
    with self.assertRaises(RuntimeError):
      loss.gradient(w)

  def test_with_custom_gradient(self):
    """The gradient= argument scales the result: d(sum(x*x))/dx is 2x, times 3."""
    x = Tensor([1.0, 2.0, 3.0])
    z = (x * x).sum()
    upstream = Tensor([3.0])
    (dx,) = z.gradient(x, gradient=upstream)
    self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])

  def test_broadcast_gradient(self):
    """Gradients of a broadcast add are reduced back to each operand's shape."""
    x = Tensor([[1.0], [2.0], [3.0]])
    y = Tensor([[10.0, 20.0, 30.0, 40.0]])
    z = (x + y).sum()
    dx, dy = z.gradient(x, y)
    self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]])
    self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]])

  def test_non_scalar_output(self):
    """A non-scalar output needs an explicit gradient=; without one an AssertionError is raised."""
    x = Tensor([1.0, 2.0, 3.0])
    z = x * x
    with self.assertRaises(AssertionError):
      z.gradient(x)
    dz = Tensor([1.0, 1.0, 1.0])
    (dx,) = z.gradient(x, gradient=dz)
    self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])

  def test_cast_before_view(self):
    """Taking the gradient with respect to a reshaped tensor that is later cast must not raise."""
    x = Tensor([1.0, 1, 1, 1])
    x_reshaped = x.reshape(2,2)
    x_casted = x_reshaped.cast(dtypes.float16)
    x_casted.mean().gradient(x_reshaped)

  def test_non_float_tensor_raise(self):
    """Requesting the gradient with respect to an int tensor raises RuntimeError, even after a float cast."""
    x = Tensor([1, 2, 3])
    for loss in (x.sum(), x.float().sum()):
      with self.assertRaises(RuntimeError):
        loss.gradient(x)

  def test_copy_to_device_gradient(self):
    """After computing on a copy on "CPU:1", the gradient lives on the original tensor's device."""
    t = Tensor([1.0, 2, 3], requires_grad=True).realize()
    t.to("CPU:1").square().sum().backward()
    self.assertEqual(t.grad.device, t.device)
    self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])

  def test_multiple_backward(self):
    """Repeated backward calls add into x.grad and keep the same grad Tensor object."""
    x = Tensor([3.], requires_grad=True)

    (x*2)[0].backward()
    np.testing.assert_allclose(x.grad.numpy(), [2.0])
    first_grad = x.grad

    (x*3)[0].backward()
    np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0])
    self.assertIs(x.grad, first_grad)

    (x*x)[0].backward()
    np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0+2*3.0])
    self.assertIs(x.grad, first_grad)
class TestViewGradient(unittest.TestCase):
  """Gradient behaviour when the differentiated tensor is a view (expand) of another tensor."""

  def test_expand(self):
    """The grad stored on the expanded view differs from the expanded grad of its source."""
    # this test shows that if Tensors collapse to the views and create a disconnected graph
    # there's no way to recover the proper gradient
    x = Tensor.randn(5,2)
    a = Tensor([3.], requires_grad=True)
    aex = a.expand(10)
    loss = (aex.reshape(5,2) * x).sum()
    loss.backward()
    np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())
    # NOTE: aex.grad is *not* a.grad.expand(10)!
    with self.assertRaises(AssertionError):
      np.testing.assert_allclose(aex.grad.numpy(), a.grad.expand(10).numpy())
# run this file's tests when it is executed as a script
if __name__ == '__main__':
  unittest.main()

11
test/test_helpers.py Normal file
View file

@ -0,0 +1,11 @@
import unittest
import numpy as np
from tinygrad.helpers import polyN
class TestPolyN(unittest.TestCase):
  """Tests for the polyN helper."""

  def test_tensor(self):
    """polyN accepts a Tensor: with coefficients [1, -2, 1] the inputs 1..4 map to 0, 1, 4, 9."""
    from tinygrad.tensor import Tensor
    points = Tensor([1.0, 2.0, 3.0, 4.0])
    coefficients = [1.0, -2.0, 1.0]
    evaluated = polyN(points, coefficients).numpy()
    np.testing.assert_allclose(evaluated, [0.0, 1.0, 4.0, 9.0])
# run this file's tests when it is executed as a script
if __name__ == '__main__':
  unittest.main()

View file

@ -1,9 +1,8 @@
import unittest, math, operator, subprocess, struct import unittest, math, struct
from tinygrad.tensor import Tensor, dtypes, Device from tinygrad.tensor import dtypes
from tinygrad.dtype import DType, DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float from tinygrad.dtype import DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype
from tinygrad.device import is_dtype_supported from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG from tinygrad.helpers import getenv
from test.helpers import slow
from hypothesis import given, settings, strategies as strat from hypothesis import given, settings, strategies as strat
import numpy as np import numpy as np
import torch import torch
@ -12,22 +11,10 @@ settings.register_profile("my_profile", max_examples=50, deadline=None, derandom
settings.load_profile("my_profile") settings.load_profile("my_profile")
core_dtypes = list(DTYPES_DICT.values()) core_dtypes = list(DTYPES_DICT.values())
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
FP8E4M3_MAX = 448.0 FP8E4M3_MAX = 448.0
FP8E5M2_MAX = 57344.0 FP8E5M2_MAX = 57344.0
def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0] def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0]
def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0] def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0]
@ -202,163 +189,6 @@ class TestHelpers(unittest.TestCase):
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX) elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item()) else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item())
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
def test_set_dtype_default(self):
for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
dtypes.default_int = default_int
assert dtypes.default_int == default_int
for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
dtypes.default_float = default_float
assert dtypes.default_float == default_float
@unittest.skip("this test is slow and spawning whole pythons")
def test_env_set_default_float(self):
# check default
subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
shell=True, check=True)
# check change
subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
shell=True, check=True)
# check invalid
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
def test_dtype_str_arg(self):
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
tested = 0
for dtype_str, dtype in [
("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
if dtype.itemsize == 4:
np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
tested += 1
assert tested == 3
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_creation(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor(True), dtypes.bool, True)
_assert_eq(Tensor(None), dtypes.default_float, [])
_assert_eq(Tensor(2), dtypes.default_int, 2)
_assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
_assert_eq(Tensor([]), dtypes.default_float, [])
_assert_eq(Tensor([1]), dtypes.default_int, [1])
_assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_full(self, default_int, default_float):
  # zeros/ones/full use the defaults for their fill value type unless dtype is passed explicitly
  dtypes.default_int, dtypes.default_float = default_int, default_float
  idt, fdt = dtypes.default_int, dtypes.default_float
  shape = (2, 3)
  explicit_dtypes = (dtypes.int64, dtypes.float16)

  # zeros and ones share the same expectations
  for make, reference in ((Tensor.zeros, np.zeros), (Tensor.ones, np.ones)):
    _assert_eq(make(shape), fdt, reference(shape))
    for dt in explicit_dtypes:
      if is_dtype_supported(dt):
        _assert_eq(make(shape, dtype=dt), dt, reference(shape))

  # full: dtype follows the python type of the fill value
  _assert_eq(Tensor.full(shape, 3.0), fdt, np.full(shape, 3.0))
  _assert_eq(Tensor.full(shape, 3), idt, np.full(shape, 3))
  _assert_eq(Tensor.full(shape, True), dtypes.bool, np.full(shape, True))

  # full: an explicit dtype wins over the fill value's type
  for dt in explicit_dtypes:
    if not is_dtype_supported(dt): continue
    for fill in (3, 3.0):
      _assert_eq(Tensor.full(shape, fill, dtype=dt), dt, np.full(shape, 3))
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_reduce_0d_default(self, default_int, default_float):
  # summing over a zero-length axis yields zeros in the input dtype
  dtypes.default_int, dtypes.default_float = default_int, default_float
  expected = np.zeros((2, 3))
  empty_shape = (2, 3, 0)

  _assert_eq(Tensor.ones(empty_shape).sum(2), dtypes.default_float, expected)
  # TODO: what should this one be?
  # _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
  _assert_eq(Tensor.ones(empty_shape, dtype=dtypes.int32).sum(2), dtypes.int32, expected)
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_arange(self, default_int, default_float):
  # arange uses default int for int arguments, default float if any argument is a float
  dtypes.default_int, dtypes.default_float = default_int, default_float
  idt, fdt = dtypes.default_int, dtypes.default_float

  for stop in (5, 120):
    _assert_eq(Tensor.arange(stop), idt, np.arange(stop))
  _assert_eq(Tensor.arange(5.0), fdt, np.arange(5))

  # explicit dtype is honoured when the device supports it
  for dt in (dtypes.int16, dtypes.int64, dtypes.float16):
    if is_dtype_supported(dt):
      _assert_eq(Tensor.arange(5, dtype=dt), dt, np.arange(5))

  # fractional step gets a looser tolerance on WEBGPU
  frac_step_tol = 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7
  _assert_eq(Tensor.arange(3, 9, 0.7), fdt, np.arange(3, 9, 0.7), frac_step_tol)
  _assert_eq(Tensor.arange(3, 8.5, 3), fdt, np.arange(3, 8.5, 3))

  # stop-start and step have different signs
  _assert_eq(Tensor.arange(3, 5, -2), idt, np.arange(3, 5, -2))
  _assert_eq(Tensor.arange(5.0, 3.0), fdt, np.arange(5.0, 3.0))
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
def test_bool_ops(self, dtype, op):
  # comparisons always produce bool, whatever the operand dtype
  lhs = Tensor.ones(4, 4, dtype=dtype)
  rhs = Tensor.ones(4, 4, dtype=dtype)
  compared = op(lhs, rhs)
  assert compared.dtype == dtypes.bool
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_functions_return_index(self, dtype, default_int, default_float):
  # index-returning functions give int32 regardless of input dtype and of the configured defaults
  dtypes.default_int, dtypes.default_float = default_int, default_float
  for index_fn in (Tensor.argmax, Tensor.argmin, Tensor.multinomial):
    assert index_fn(Tensor([0, 1], dtype=dtype)).dtype == dtypes.int32
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
  # indexing with an integer tensor keeps the dtype of the indexed data
  dataset = Tensor.ones(60000, 1, 28, 28, dtype=data_dtype)
  picks = Tensor.randint(512, high=dataset.shape[0]).cast(indices_dtype)
  selected = dataset[picks]
  assert selected.dtype == dataset.dtype
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
  # gather keeps the dtype of the source tensor along either axis
  source = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
  index = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
  for axis in (0, 1):
    assert source.gather(axis, index).dtype == source.dtype
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_attention_returns_same_dtype(self, data_dtype, default_float):
  # attention output keeps the dtype of its inputs, independent of the default float
  dtypes.default_float = default_float
  qkv_shape = (32, 8, 128, 64)
  query = Tensor.rand(*qkv_shape, dtype=data_dtype)
  key = Tensor.rand(*qkv_shape, dtype=data_dtype)
  value = Tensor.rand(*qkv_shape, dtype=data_dtype)
  mask = (Tensor.rand(32, 8, 128, 128) < 0.5)

  attention_variants = (
    {"is_causal": True},
    {"is_causal": True, "dropout_p": 0.3},
    {"is_causal": False},
    {"attn_mask": mask},
  )
  for kwargs in attention_variants:
    assert query.scaled_dot_product_attention(key, value, **kwargs).dtype == data_dtype
class TestTypePromotion(unittest.TestCase): class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes)) @given(strat.sampled_from(core_dtypes))
def test_self_promo_to_self(self, dtype): def test_self_promo_to_self(self, dtype):
@ -398,237 +228,5 @@ class TestTypePromotion(unittest.TestCase):
assert least_upper_dtype(dtypes.fp8e5m2, dtypes.int64) == dtypes.fp8e5m2 assert least_upper_dtype(dtypes.fp8e5m2, dtypes.int64) == dtypes.fp8e5m2
assert least_upper_dtype(dtypes.fp8e5m2, dtypes.uint64) == dtypes.fp8e5m2 assert least_upper_dtype(dtypes.fp8e5m2, dtypes.uint64) == dtypes.fp8e5m2
class TestAutoCastType(unittest.TestCase): if __name__ == '__main__':
def setUp(self): unittest.main()
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
  # put back the default dtypes that were saved on self before the test ran
  saved_int, saved_float = self.old_default_int, self.old_default_float
  dtypes.default_int = saved_int
  dtypes.default_float = saved_float
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_float(self, input_dtype, default_float):
  # a float dtype is returned unchanged, whatever the default float is
  dtypes.default_float = default_float
  promoted = least_upper_float(input_dtype)
  self.assertEqual(promoted, input_dtype)
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_int(self, input_dtype, default_float):
  # an int dtype is promoted to the configured default float
  dtypes.default_float = default_float
  promoted = least_upper_float(input_dtype)
  self.assertEqual(promoted, default_float)
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
def test_int_to_float_unary_func(self, dtype):
  # float-valued unary functions applied to int tensors are compared against torch
  a = [2, 3, 4]
  unary_names = ("exp", "exp2", "log", "log2", "sqrt", "rsqrt", "sin", "cos", "tan", "sigmoid")
  for name in unary_names:
    ours = getattr(Tensor(a, dtype=dtype), name)().numpy()
    reference = getattr(torch.tensor(a), name)()
    # float16 can have larger precision errors
    np.testing.assert_allclose(ours, reference, rtol=1e-3, atol=1e-3)
@given(strat.sampled_from(core_dtypes))
def test_broadcast_scalar(self, dt):
  # adding a python scalar only changes the dtype when the tensor dtype cannot hold that kind of scalar
  is_float = dtypes.is_float(dt)
  with_float = dt if is_float else dtypes.default_float
  assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == with_float
  with_int = dt if is_float or dtypes.is_int(dt) else dtypes.default_int
  assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == with_int
  assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
@given(strat.sampled_from(dtype_floats))
def test_int_div_int(self, default_float):
  # dividing two int tensors yields the default float
  dtypes.default_float = default_float
  quotient = Tensor([1]).div(Tensor([2]))
  self.assertEqual(quotient.dtype, default_float)
def test_sum(self):
  # (input dtype, dtype of its sum): small ints widen to 32 bits, everything else is kept
  sum_dtype_table = [
    (dtypes.bool, dtypes.int32),
    (dtypes.int8, dtypes.int32),
    (dtypes.int16, dtypes.int32),
    (dtypes.int32, dtypes.int32),
    (dtypes.int64, dtypes.int64),
    (dtypes.uint8, dtypes.uint32),
    (dtypes.uint16, dtypes.uint32),
    (dtypes.uint32, dtypes.uint32),
    (dtypes.uint64, dtypes.uint64),
    (dtypes.fp8e4m3, dtypes.fp8e4m3),
    (dtypes.fp8e5m2, dtypes.fp8e5m2),
    (dtypes.float16, dtypes.float16),
    (dtypes.bfloat16, dtypes.bfloat16),
    (dtypes.float32, dtypes.float32),
    (dtypes.float64, dtypes.float64),
  ]
  for in_dtype, out_dtype in sum_dtype_table:
    assert (Tensor([0, 1], dtype=in_dtype)).sum().dtype == out_dtype
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
def test_sum_dtype_arg(self):
  t = Tensor([40000, 40000], dtype=dtypes.float16)

  # without a dtype the sum stays in float16, which overflows to inf for these values
  half_sum = t.sum()
  assert half_sum.dtype == dtypes.float16
  assert math.isinf(t.sum().numpy().item())

  # with an explicit dtype the result is returned in that dtype and is not downcast
  wide_sum = t.sum(dtype=dtypes.float32)
  assert wide_sum.dtype == dtypes.float32
  np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)
def test_prod_dtype_arg(self):
  # prod keeps the input dtype unless an explicit dtype is requested
  t = Tensor([100, 200], dtype=dtypes.int32)
  for kwargs, expected_dtype in (({}, dtypes.int32), ({"dtype": dtypes.float32}, dtypes.float32)):
    assert t.prod(**kwargs).dtype == expected_dtype
    np.testing.assert_allclose(t.prod(**kwargs).numpy(), 20000)
def test_mean(self):
  # bool and integer inputs are expected to average to float32
  non_float_inputs = [dtypes.bool, dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
                      dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
  for in_dtype in non_float_inputs:
    assert (Tensor([0, 1], dtype=in_dtype)).mean().dtype == dtypes.float32

  # float inputs keep their own dtype
  float_inputs = [dtypes.fp8e4m3, dtypes.fp8e5m2, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]
  for in_dtype in float_inputs:
    assert (Tensor([0, 1], dtype=in_dtype)).mean().dtype == in_dtype
def test_cumsum(self):
  # (input dtype, dtype of its cumsum): same expectations as the sum table
  cumsum_dtype_table = [
    (dtypes.bool, dtypes.int32),
    (dtypes.int8, dtypes.int32),
    (dtypes.int16, dtypes.int32),
    (dtypes.int32, dtypes.int32),
    (dtypes.int64, dtypes.int64),
    (dtypes.uint8, dtypes.uint32),
    (dtypes.uint16, dtypes.uint32),
    (dtypes.uint32, dtypes.uint32),
    (dtypes.uint64, dtypes.uint64),
    (dtypes.fp8e4m3, dtypes.fp8e4m3),
    (dtypes.fp8e5m2, dtypes.fp8e5m2),
    (dtypes.float16, dtypes.float16),
    (dtypes.bfloat16, dtypes.bfloat16),
    (dtypes.float32, dtypes.float32),
    (dtypes.float64, dtypes.float64),
  ]
  for in_dtype, out_dtype in cumsum_dtype_table:
    assert (Tensor([0, 1], dtype=in_dtype)).cumsum(0).dtype == out_dtype
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_matmul(self, dt1, dt2, acc_dt):
  lhs, rhs = Tensor([0, 1], dtype=dt1), Tensor([0, 1], dtype=dt2)
  # by default the result is the promoted dtype of both operands
  product = lhs.matmul(rhs)
  self.assertEqual(product.dtype, least_upper_dtype(lhs.dtype, rhs.dtype))
  # if dtype is specified, return in dtype
  product_acc = lhs.matmul(rhs, dtype=acc_dt)
  self.assertEqual(product_acc.dtype, acc_dt)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_linear(self, dt1, dt2, dt3, acc_dt):
  x = Tensor([0, 1], dtype=dt1)
  w = Tensor([0, 1], dtype=dt2)
  b = Tensor([0, 1], dtype=dt3)

  # by default the result is the promoted dtype of input and weight (and bias when given)
  self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
  self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))

  # if dtype is specified, return in dtype
  for extra_args in ((), (b,)):
    self.assertEqual(x.linear(w, *extra_args, dtype=acc_dt).dtype, acc_dt)
@staticmethod
def check_where_alternate_input_other(input_, other, data_type):
  # the result dtype of where must be data_type whichever branch each operand is passed as
  for on_true, on_false in ((input_, other), (other, input_)):
    assert (Tensor([True, False]).where(on_true, on_false)).dtype == data_type
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_where_no_scalar(self, dt1, dt2):
  # two tensor branches promote to their least upper dtype
  first, second = Tensor(2, dtype=dt1), Tensor(3, dtype=dt2)
  self.check_where_alternate_input_other(first, second, least_upper_dtype(dt1, dt2))
@given(strat.sampled_from(core_dtypes))
def test_where_one_scalar(self, dt):
  # a python scalar branch only changes the dtype when the tensor dtype cannot hold that kind of scalar
  t = Tensor(2, dtype=dt)
  is_float = dtypes.is_float(dt)
  with_float = dt if is_float else dtypes.default_float
  self.check_where_alternate_input_other(t, 3.2, with_float)
  with_int = dt if is_float or dtypes.is_int(dt) else dtypes.default_int
  self.check_where_alternate_input_other(t, 3, with_int)
  self.check_where_alternate_input_other(t, True, dt)
def test_where_two_scalars(self):
  # with two python scalars, float beats int beats bool
  float_pairs = ((3.1, 3.2), (3.1, 3), (3.1, True))
  for lhs, rhs in float_pairs:
    self.check_where_alternate_input_other(lhs, rhs, dtypes.default_float)
  int_pairs = ((3, 2), (3, True))
  for lhs, rhs in int_pairs:
    self.check_where_alternate_input_other(lhs, rhs, dtypes.default_int)
  self.check_where_alternate_input_other(False, True, dtypes.bool)
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
def test_maximum(self, dt1, dt2):
  # maximum of two tensors has their least upper dtype
  lhs = Tensor([0, 1, 2], dtype=dt1)
  rhs = Tensor([2, 0, 5], dtype=dt2)
  assert lhs.maximum(rhs).dtype == least_upper_dtype(dt1, dt2)
@given(strat.sampled_from(core_dtypes))
def test_maximum_const(self, dt):
  # maximum with a python scalar follows the same promotion as other scalar ops
  is_float = dtypes.is_float(dt)
  with_float = dt if is_float else dtypes.default_float
  assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == with_float
  with_int = dt if is_float or dtypes.is_int(dt) else dtypes.default_int
  assert Tensor([1, 2], dtype=dt).maximum(3).dtype == with_int
  assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
def test_div(self):
  # (numerator dtype, denominator dtype, expected dtype of the quotient)
  div_cases = (
    (dtypes.int32, dtypes.int32, dtypes.default_float),
    (dtypes.int16, dtypes.int32, dtypes.default_float),
    (dtypes.float32, dtypes.float16, dtypes.float32),
    (dtypes.int32, dtypes.float16, dtypes.float16),
  )
  for num_dtype, den_dtype, expected in div_cases:
    assert (Tensor([1, 2], dtype=num_dtype) / Tensor([2, 2], dtype=den_dtype)).dtype == expected
def test_div_const(self):
  # (tensor dtype, python divisor, expected dtype of the quotient)
  div_const_cases = (
    (dtypes.int32, 2, dtypes.default_float),
    (dtypes.int32, 2.0, dtypes.default_float),
    (dtypes.float16, 2, dtypes.float16),
    (dtypes.float16, 2.0, dtypes.float16),
  )
  for tensor_dtype, divisor, expected in div_const_cases:
    assert (Tensor([1, 2], dtype=tensor_dtype) / divisor).dtype == expected
def test_gradient_dtype(self):
  # for every supported (default float, tensor dtype) pair, the gradient must have the dtype of its tensor
  old_default_float = dtypes.default_float
  try:
    for default_dtype in dtypes.floats:
      if not is_dtype_supported(default_dtype): continue
      dtypes.default_float = default_dtype
      for dtype in dtypes.floats:
        if not is_dtype_supported(dtype): continue
        if DEBUG >= 2:
          print(f"testing {default_dtype=}, {dtype=}")
        a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
        b = (a * 5).sum()
        b.backward()  # if there is dtype mismatch, lazy should assert
        assert a.grad.dtype == a.dtype
        np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
  finally:
    # restore on failure too, so a failed assertion does not leave the global default float changed
    dtypes.default_float = old_default_float
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
@slow
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_underflow(self):
  # row means of a large half tensor filled with a small constant must still equal that constant
  N = 10000
  x = 0.001
  t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
  row_means = t.mean(axis=1).numpy()
  expected = np.array([x] * N, dtype=np.float16)
  np.testing.assert_allclose(row_means, expected, rtol=1e-3)
@unittest.skip("this test only works with SPLIT_REDUCEOP=1")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_overflow(self):
  # mean of large half values, and the gradient of the mean of their squares
  N = 256
  count = N*N
  t = Tensor([60000] * count, dtype=dtypes.half, requires_grad=True).reshape(N, N)
  np.testing.assert_allclose(t.mean().numpy(), 60000)
  t.square().mean().backward()
  expected_grad = [60000 * 2 / (N*N)] * count
  np.testing.assert_allclose(t.grad.numpy().flatten(), expected_grad)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_softmax_dtype(self):
  # softmax and log_softmax keep half by default and honour an explicit dtype; values are checked against torch
  data = [1, 2, 3]
  t = Tensor(data, dtype=dtypes.half)
  tt = torch.tensor(data, dtype=torch.half)

  for fn_name in ("softmax", "log_softmax"):
    ours, reference = getattr(t, fn_name), getattr(tt, fn_name)

    out = ours(0)
    self.assertEqual(out.dtype, dtypes.half)
    np.testing.assert_allclose(out.numpy(), reference(0).numpy(), rtol=1e-3)

    out = ours(0, dtype=dtypes.float)
    self.assertEqual(out.dtype, dtypes.float)
    np.testing.assert_allclose(out.numpy(), reference(0, dtype=torch.float).numpy(), rtol=1e-3)

View file

@ -1,7 +1,6 @@
from typing import Callable from typing import Callable
import unittest, math import unittest, math
import torch import torch
import numpy as np
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp from tinygrad.uop.ops import UOp
@ -63,71 +62,6 @@ class TestGradient(unittest.TestCase):
def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y) def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y)
def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: torch.where(x<y,x,y)) def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: torch.where(x<y,x,y))
class TestTensorGradient(unittest.TestCase):
  # tests for Tensor.gradient and Tensor.backward

  def test_example(self):
    # gradient of sum(y @ x) with respect to both operands
    x = Tensor.eye(3)
    y = Tensor([[2.0,0,-2.0]])
    z = y.matmul(x).sum()
    grads = z.gradient(x, y)
    dx, dy = grads
    self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
    self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])

  def test_raises(self):
    # asking for the gradient with respect to a tensor that is not part of the graph is an error
    x = Tensor([1.0, 2.0, 3.0])
    unrelated = Tensor.randn((3,))
    with self.assertRaises(RuntimeError):
      x.sum().gradient(unrelated)

  def test_with_custom_gradient(self):
    # an explicit output gradient scales the result
    x = Tensor([1.0, 2.0, 3.0])
    z = (x * x).sum()
    seed = Tensor([3.0])
    dx = z.gradient(x, gradient=seed)[0]
    self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])

  def test_broadcast_gradient(self):
    # gradients of broadcast operands come back in the operands' own shapes
    x = Tensor([[1.0], [2.0], [3.0]])
    y = Tensor([[10.0, 20.0, 30.0, 40.0]])
    z = (x + y).sum()
    dx, dy = z.gradient(x, y)
    self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]])
    self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]])

  def test_non_scalar_output(self):
    # a non-scalar output needs an explicit output gradient
    x = Tensor([1.0, 2.0, 3.0])
    z = x * x
    with self.assertRaises(AssertionError):
      z.gradient(x)
    dz = Tensor([1.0, 1.0, 1.0])
    dx = z.gradient(x, gradient=dz)[0]
    self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])

  def test_cast_before_view(self):
    # only checks that taking this gradient does not raise
    x = Tensor([1.0, 1, 1, 1])
    x_reshaped = x.reshape(2,2)
    x_casted = x_reshaped.cast(dtypes.float16)
    loss = x_casted.mean()
    loss.gradient(x_reshaped)

  def test_non_float_tensor_raise(self):
    # gradients with respect to an int tensor are rejected, directly or through a float cast
    x = Tensor([1, 2, 3])
    with self.assertRaises(RuntimeError):
      x.sum().gradient(x)
    with self.assertRaises(RuntimeError):
      x.float().sum().gradient(x)

  def test_copy_to_device_gradient(self):
    # the gradient lands on the device of the original tensor, not the copy
    t = Tensor([1.0, 2, 3], requires_grad=True).realize()
    loss = t.to("CPU:1").square().sum()
    loss.backward()
    self.assertEqual(t.grad.device, t.device)
    self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])

  def test_multiple_backward(self):
    # repeated backward calls accumulate into the same grad tensor object
    x = Tensor([3.], requires_grad=True)

    (x*2)[0].backward()
    np.testing.assert_allclose(x.grad.numpy(), [2.0])
    first_grad = x.grad

    (x*3)[0].backward()
    np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0])
    self.assertIs(x.grad, first_grad)

    (x*x)[0].backward()
    np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0+2*3.0])
    self.assertIs(x.grad, first_grad)
class TestRealizeMeansRealize(unittest.TestCase): class TestRealizeMeansRealize(unittest.TestCase):
def test_randn_realizes(self): def test_randn_realizes(self):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize() x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
@ -148,18 +82,5 @@ class TestRealizeMeansRealize(unittest.TestCase):
y = x * 2 y = x * 2
y.sum().gradient(x)[0].realize() y.sum().gradient(x)[0].realize()
class TestViewGradient(unittest.TestCase):
  def test_expand(self):
    # this test shows that if Tensors collapse to the views and create a disconnected graph
    # there's no way to recover the proper gradient
    x = Tensor.randn(5,2)
    a = Tensor([3.], requires_grad=True)
    aex = a.expand(10)
    loss = (aex.reshape(5,2) * x).sum()
    loss.backward()
    expanded_grad = aex.grad.numpy()
    np.testing.assert_allclose(expanded_grad, x.reshape(10).numpy())
    # NOTE: aex.grad is *not* a.grad.expand(10)!
    with self.assertRaises(AssertionError):
      np.testing.assert_allclose(aex.grad.numpy(), a.grad.expand(10).numpy())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -356,10 +356,6 @@ class TestPolyN(unittest.TestCase):
np.testing.assert_allclose(polyN(3.0, [1.0, -2.0, 1.0]), 4.0) np.testing.assert_allclose(polyN(3.0, [1.0, -2.0, 1.0]), 4.0)
np.testing.assert_allclose(polyN(4.0, [1.0, -2.0, 1.0]), 9.0) np.testing.assert_allclose(polyN(4.0, [1.0, -2.0, 1.0]), 9.0)
def test_tensor(self):
  # polyN applied to a Tensor is evaluated element-wise; same coefficients as the scalar checks
  from tinygrad.tensor import Tensor
  coefficients = [1.0, -2.0, 1.0]
  inputs = Tensor([1.0, 2.0, 3.0, 4.0])
  result = polyN(inputs, coefficients).numpy()
  np.testing.assert_allclose(result, [0.0, 1.0, 4.0, 9.0])
def test_uop(self): def test_uop(self):
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp from tinygrad.uop.ops import UOp