mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66634c643e |
24 changed files with 527 additions and 491 deletions
425
test/test_dtype_spec.py
Normal file
425
test/test_dtype_spec.py
Normal file
|
|
@ -0,0 +1,425 @@
|
|||
import unittest, math, operator, subprocess
|
||||
from tinygrad.tensor import Tensor, dtypes, Device
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, least_upper_float
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from test.helpers import slow
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
settings.register_profile("my_profile", max_examples=50, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
core_dtypes = list(DTYPES_DICT.values())
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
|
||||
|
||||
def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
|
||||
if DEBUG >= 2: print(tensor.numpy())
|
||||
try:
|
||||
assert tensor.dtype == target_dtype
|
||||
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
|
||||
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
|
||||
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
def tearDown(self):
|
||||
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
|
||||
|
||||
def test_set_dtype_default(self):
|
||||
for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
|
||||
dtypes.default_int = default_int
|
||||
assert dtypes.default_int == default_int
|
||||
|
||||
for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||
dtypes.default_float = default_float
|
||||
assert dtypes.default_float == default_float
|
||||
|
||||
@unittest.skip("this test is slow and spawning whole pythons")
|
||||
def test_env_set_default_float(self):
|
||||
# check default
|
||||
subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
|
||||
shell=True, check=True)
|
||||
# check change
|
||||
subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
|
||||
shell=True, check=True)
|
||||
# check invalid
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
|
||||
def test_dtype_str_arg(self):
|
||||
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
|
||||
tested = 0
|
||||
for dtype_str, dtype in [
|
||||
("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
|
||||
np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
|
||||
np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
|
||||
if dtype.itemsize == 4:
|
||||
np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
|
||||
tested += 1
|
||||
assert tested == 3
|
||||
|
||||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
|
||||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
|
||||
|
||||
np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_creation(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
_assert_eq(Tensor(True), dtypes.bool, True)
|
||||
_assert_eq(Tensor(None), dtypes.default_float, [])
|
||||
_assert_eq(Tensor(2), dtypes.default_int, 2)
|
||||
_assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
|
||||
_assert_eq(Tensor([]), dtypes.default_float, [])
|
||||
_assert_eq(Tensor([1]), dtypes.default_int, [1])
|
||||
_assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
|
||||
|
||||
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
|
||||
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_full(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
|
||||
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_reduce_0d_default(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
_assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3)))
|
||||
# TODO: what should this one be?
|
||||
# _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_arange(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
|
||||
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
|
||||
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
|
||||
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
|
||||
# stop-start and step have different signs
|
||||
_assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))
|
||||
_assert_eq(Tensor.arange(5.0, 3.0), dtypes.default_float, np.arange(5.0, 3.0))
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
|
||||
def test_bool_ops(self, dtype, op):
|
||||
assert op(Tensor.ones(4, 4, dtype=dtype), Tensor.ones(4, 4, dtype=dtype)).dtype == dtypes.bool
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_functions_return_index(self, dtype, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
|
||||
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
|
||||
X_data = Tensor.ones(60000, 1, 28, 28, dtype=data_dtype)
|
||||
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
|
||||
assert X_data[indices].dtype == X_data.dtype
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
|
||||
def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
|
||||
X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
|
||||
indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
|
||||
assert X_data.gather(0, indices).dtype == X_data.dtype
|
||||
assert X_data.gather(1, indices).dtype == X_data.dtype
|
||||
|
||||
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
|
||||
def test_attention_returns_same_dtype(self, data_dtype, default_float):
|
||||
dtypes.default_float = default_float
|
||||
query = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
|
||||
key = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
|
||||
value = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
|
||||
mask = (Tensor.rand(32, 8, 128, 128) < 0.5)
|
||||
assert query.scaled_dot_product_attention(key, value, is_causal=True).dtype == data_dtype
|
||||
assert query.scaled_dot_product_attention(key, value, is_causal=True, dropout_p=0.3).dtype == data_dtype
|
||||
assert query.scaled_dot_product_attention(key, value, is_causal=False).dtype == data_dtype
|
||||
assert query.scaled_dot_product_attention(key, value, attn_mask=mask).dtype == data_dtype
|
||||
|
||||
class TestAutoCastType(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
def tearDown(self):
|
||||
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
|
||||
|
||||
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
|
||||
def test_least_upper_float_input_is_float(self, input_dtype, default_float):
|
||||
dtypes.default_float = default_float
|
||||
self.assertEqual(least_upper_float(input_dtype), input_dtype)
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_least_upper_float_input_is_int(self, input_dtype, default_float):
|
||||
dtypes.default_float = default_float
|
||||
self.assertEqual(least_upper_float(input_dtype), default_float)
|
||||
|
||||
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
|
||||
def test_int_to_float_unary_func(self, dtype):
|
||||
for func in [
|
||||
lambda t: t.exp(),
|
||||
lambda t: t.exp2(),
|
||||
lambda t: t.log(),
|
||||
lambda t: t.log2(),
|
||||
lambda t: t.sqrt(),
|
||||
lambda t: t.rsqrt(),
|
||||
lambda t: t.sin(),
|
||||
lambda t: t.cos(),
|
||||
lambda t: t.tan(),
|
||||
lambda t: t.sigmoid(),
|
||||
]:
|
||||
a = [2, 3, 4]
|
||||
# float16 can have larger precision errors
|
||||
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_broadcast_scalar(self, dt):
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
|
||||
|
||||
@given(strat.sampled_from(dtype_floats))
|
||||
def test_int_div_int(self, default_float):
|
||||
dtypes.default_float = default_float
|
||||
self.assertEqual(Tensor([1]).div(Tensor([2])).dtype, default_float)
|
||||
|
||||
def test_sum(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
|
||||
def test_sum_dtype_arg(self):
|
||||
t = Tensor([40000, 40000], dtype=dtypes.float16)
|
||||
# default float16 sum returns in float16, overflowed in this case
|
||||
assert t.sum().dtype == dtypes.float16
|
||||
assert math.isinf(t.sum().numpy().item())
|
||||
# specifiying dtype and it's not downcasted
|
||||
assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)
|
||||
|
||||
def test_prod_dtype_arg(self):
|
||||
t = Tensor([100, 200], dtype=dtypes.int32)
|
||||
assert t.prod().dtype == dtypes.int32
|
||||
np.testing.assert_allclose(t.prod().numpy(), 20000)
|
||||
assert t.prod(dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.prod(dtype=dtypes.float32).numpy(), 20000)
|
||||
|
||||
def test_mean(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int16)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int64)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint8)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
|
||||
|
||||
def test_cumsum(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum(0).dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum(0).dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_matmul(self, dt1, dt2, acc_dt):
|
||||
t1 = Tensor([0, 1], dtype=dt1)
|
||||
t2 = Tensor([0, 1], dtype=dt2)
|
||||
from tinygrad.dtype import least_upper_dtype
|
||||
self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
|
||||
# if dtype is specified, return in dtype
|
||||
self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_linear(self, dt1, dt2, dt3, acc_dt):
|
||||
x = Tensor([0, 1], dtype=dt1)
|
||||
w = Tensor([0, 1], dtype=dt2)
|
||||
b = Tensor([0, 1], dtype=dt3)
|
||||
from tinygrad.dtype import least_upper_dtype
|
||||
self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
|
||||
self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
|
||||
# if dtype is specified, return in dtype
|
||||
self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
|
||||
self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
|
||||
|
||||
@staticmethod
|
||||
def check_where_alternate_input_other(input_, other, data_type):
|
||||
assert (Tensor([True, False]).where(input_, other)).dtype == data_type
|
||||
assert (Tensor([True, False]).where(other, input_)).dtype == data_type
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_where_no_scalar(self, dt1, dt2):
|
||||
from tinygrad.dtype import least_upper_dtype
|
||||
self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_where_one_scalar(self, dt):
|
||||
t = Tensor(2, dtype=dt)
|
||||
self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
|
||||
self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
|
||||
self.check_where_alternate_input_other(t, True, dt)
|
||||
|
||||
def test_where_two_scalars(self):
|
||||
self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3, 2, dtypes.default_int)
|
||||
self.check_where_alternate_input_other(3, True, dtypes.default_int)
|
||||
self.check_where_alternate_input_other(False, True, dtypes.bool)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_maximum(self, dt1, dt2):
|
||||
from tinygrad.dtype import least_upper_dtype
|
||||
assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_maximum_const(self, dt):
|
||||
assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
|
||||
assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
|
||||
assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
|
||||
|
||||
def test_div(self):
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.int16) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.float32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float32
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float16
|
||||
|
||||
def test_div_const(self):
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / 2).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / 2.0).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
|
||||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
|
||||
|
||||
def test_gradient_dtype(self):
|
||||
old_default_float = dtypes.default_float
|
||||
|
||||
for default_dtype in dtypes.floats:
|
||||
if not is_dtype_supported(default_dtype): continue
|
||||
dtypes.default_float = default_dtype
|
||||
for dtype in dtypes.floats:
|
||||
if not is_dtype_supported(dtype): continue
|
||||
if DEBUG >= 2:
|
||||
print(f"testing {default_dtype=}, {dtype=}")
|
||||
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
|
||||
b = (a * 5).sum()
|
||||
b.backward() # if there is dtype mismatch, lazy should assert
|
||||
assert a.grad.dtype == a.dtype
|
||||
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
|
||||
|
||||
dtypes.default_float = old_default_float
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
|
||||
@slow
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_mean_half_precision_underflow(self):
|
||||
N = 10000
|
||||
x = 0.001
|
||||
t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
|
||||
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
|
||||
|
||||
@unittest.skip("this test only works with SPLIT_REDUCEOP=1")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_mean_half_precision_overflow(self):
|
||||
N = 256
|
||||
t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
|
||||
np.testing.assert_allclose(t.mean().numpy(), 60000)
|
||||
t.square().mean().backward()
|
||||
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_softmax_dtype(self):
|
||||
data = [1, 2, 3]
|
||||
t = Tensor(data, dtype=dtypes.half)
|
||||
tt = torch.tensor(data, dtype=torch.half)
|
||||
|
||||
out = t.softmax(0)
|
||||
self.assertEqual(out.dtype, dtypes.half)
|
||||
np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3)
|
||||
out = t.softmax(0, dtype=dtypes.float)
|
||||
self.assertEqual(out.dtype, dtypes.float)
|
||||
np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
|
||||
out = t.log_softmax(0)
|
||||
self.assertEqual(out.dtype, dtypes.half)
|
||||
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3)
|
||||
out = t.log_softmax(0, dtype=dtypes.float)
|
||||
self.assertEqual(out.dtype, dtypes.float)
|
||||
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
85
test/test_gradient.py
Normal file
85
test/test_gradient.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
class TestTensorGradient(unittest.TestCase):
|
||||
def test_example(self):
|
||||
x = Tensor.eye(3)
|
||||
y = Tensor([[2.0,0,-2.0]])
|
||||
z = y.matmul(x).sum()
|
||||
dx, dy = z.gradient(x, y)
|
||||
self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
|
||||
self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])
|
||||
|
||||
def test_raises(self):
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
w = Tensor.randn((3,))
|
||||
with self.assertRaises(RuntimeError): x.sum().gradient(w)
|
||||
|
||||
def test_with_custom_gradient(self):
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
z = (x * x).sum()
|
||||
dx = z.gradient(x, gradient=Tensor([3.0]))[0]
|
||||
self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])
|
||||
|
||||
def test_broadcast_gradient(self):
|
||||
x = Tensor([[1.0], [2.0], [3.0]])
|
||||
y = Tensor([[10.0, 20.0, 30.0, 40.0]])
|
||||
z = (x + y).sum()
|
||||
dx, dy = z.gradient(x, y)
|
||||
self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]])
|
||||
self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]])
|
||||
|
||||
def test_non_scalar_output(self):
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
z = x * x
|
||||
with self.assertRaises(AssertionError): z.gradient(x)
|
||||
dz = Tensor([1.0, 1.0, 1.0])
|
||||
dx = z.gradient(x, gradient=dz)[0]
|
||||
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
|
||||
|
||||
def test_cast_before_view(self):
|
||||
x = Tensor([1.0, 1, 1, 1])
|
||||
x_reshaped = x.reshape(2,2)
|
||||
x_casted = x_reshaped.cast(dtypes.float16)
|
||||
x_casted.mean().gradient(x_reshaped)
|
||||
|
||||
def test_non_float_tensor_raise(self):
|
||||
x = Tensor([1, 2, 3])
|
||||
with self.assertRaises(RuntimeError): x.sum().gradient(x)
|
||||
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
|
||||
|
||||
def test_copy_to_device_gradient(self):
|
||||
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
|
||||
t.to("CPU:1").square().sum().backward()
|
||||
self.assertEqual(t.grad.device, t.device)
|
||||
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
|
||||
|
||||
def test_multiple_backward(self):
|
||||
x = Tensor([3.], requires_grad=True)
|
||||
(x*2)[0].backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0])
|
||||
old_grad = x.grad
|
||||
(x*3)[0].backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0])
|
||||
self.assertIs(x.grad, old_grad)
|
||||
(x*x)[0].backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0+2*3.0])
|
||||
self.assertIs(x.grad, old_grad)
|
||||
|
||||
class TestViewGradient(unittest.TestCase):
|
||||
def test_expand(self):
|
||||
# this test shows that if Tensors collapse to the views and create a disconnected graph
|
||||
# there's no way to recover the proper gradient
|
||||
x = Tensor.randn(5,2)
|
||||
a = Tensor([3.], requires_grad=True)
|
||||
aex = a.expand(10)
|
||||
(aex.reshape(5,2) * x).sum().backward()
|
||||
np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())
|
||||
# NOTE: aex.grad is *not* a.grad.expand(10)!
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_allclose(aex.grad.numpy(), a.grad.expand(10).numpy())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
11
test/test_helpers.py
Normal file
11
test/test_helpers.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import polyN
|
||||
|
||||
class TestPolyN(unittest.TestCase):
|
||||
def test_tensor(self):
|
||||
from tinygrad.tensor import Tensor
|
||||
np.testing.assert_allclose(polyN(Tensor([1.0, 2.0, 3.0, 4.0]), [1.0, -2.0, 1.0]).numpy(), [0.0, 1.0, 4.0, 9.0])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,9 +1,8 @@
|
|||
import unittest, math, operator, subprocess, struct
|
||||
from tinygrad.tensor import Tensor, dtypes, Device
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype, least_upper_float
|
||||
import unittest, math, struct
|
||||
from tinygrad.tensor import dtypes
|
||||
from tinygrad.dtype import DTYPES_DICT, truncate, float_to_fp16, float_to_bf16, _to_np_dtype, least_upper_dtype
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, DEBUG
|
||||
from test.helpers import slow
|
||||
from tinygrad.helpers import getenv
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -12,22 +11,10 @@ settings.register_profile("my_profile", max_examples=50, deadline=None, derandom
|
|||
settings.load_profile("my_profile")
|
||||
|
||||
core_dtypes = list(DTYPES_DICT.values())
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
|
||||
|
||||
FP8E4M3_MAX = 448.0
|
||||
FP8E5M2_MAX = 57344.0
|
||||
|
||||
def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
|
||||
if DEBUG >= 2: print(tensor.numpy())
|
||||
try:
|
||||
assert tensor.dtype == target_dtype
|
||||
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2,
|
||||
dtypes.fp8e4m3:1e-1, dtypes.fp8e5m2:5e-1}.get(target_dtype, tol_target_dtype))
|
||||
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
|
||||
|
||||
def u32_to_f32(u): return struct.unpack('f', struct.pack('I', u))[0]
|
||||
def f32_to_u32(f): return struct.unpack('I', struct.pack('f', f))[0]
|
||||
|
||||
|
|
@ -202,163 +189,6 @@ class TestHelpers(unittest.TestCase):
|
|||
elif x < -FP8E5M2_MAX: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), -FP8E5M2_MAX)
|
||||
else: np.testing.assert_equal(truncate[dtypes.fp8e5m2](x), torch.tensor(x, dtype=torch.float8_e5m2).float().item())
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
def tearDown(self):
|
||||
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
|
||||
|
||||
def test_set_dtype_default(self):
|
||||
for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
|
||||
dtypes.default_int = default_int
|
||||
assert dtypes.default_int == default_int
|
||||
|
||||
for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||
dtypes.default_float = default_float
|
||||
assert dtypes.default_float == default_float
|
||||
|
||||
@unittest.skip("this test is slow and spawning whole pythons")
|
||||
def test_env_set_default_float(self):
|
||||
# check default
|
||||
subprocess.run(['python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.float"'],
|
||||
shell=True, check=True)
|
||||
# check change
|
||||
subprocess.run(['DEFAULT_FLOAT=HALF python3 -c "from tinygrad import dtypes; assert dtypes.default_float == dtypes.half"'],
|
||||
shell=True, check=True)
|
||||
# check invalid
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
subprocess.run(['DEFAULT_FLOAT=INT32 python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int8), f"no int8 on {Device.DEFAULT}")
|
||||
def test_dtype_str_arg(self):
|
||||
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
|
||||
tested = 0
|
||||
for dtype_str, dtype in [
|
||||
("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
|
||||
np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
|
||||
np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
|
||||
if dtype.itemsize == 4:
|
||||
np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
|
||||
tested += 1
|
||||
assert tested == 3
|
||||
|
||||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
|
||||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
|
||||
|
||||
np.testing.assert_equal(Tensor(n).sum(dtype="int16").numpy(), Tensor(n).sum(dtype=dtypes.int16).numpy())
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_creation(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
_assert_eq(Tensor(True), dtypes.bool, True)
|
||||
_assert_eq(Tensor(None), dtypes.default_float, [])
|
||||
_assert_eq(Tensor(2), dtypes.default_int, 2)
|
||||
_assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
|
||||
_assert_eq(Tensor([]), dtypes.default_float, [])
|
||||
_assert_eq(Tensor([1]), dtypes.default_int, [1])
|
||||
_assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
|
||||
|
||||
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
|
||||
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_full(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
|
||||
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_reduce_0d_default(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
_assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3)))
|
||||
# TODO: what should this one be?
|
||||
# _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_arange(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
_assert_eq(Tensor.arange(5), dtypes.default_int, np.arange(5))
|
||||
_assert_eq(Tensor.arange(120), dtypes.default_int, np.arange(120))
|
||||
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
|
||||
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
|
||||
# stop-start and step have different signs
|
||||
_assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))
|
||||
_assert_eq(Tensor.arange(5.0, 3.0), dtypes.default_float, np.arange(5.0, 3.0))
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
|
||||
def test_bool_ops(self, dtype, op):
|
||||
assert op(Tensor.ones(4, 4, dtype=dtype), Tensor.ones(4, 4, dtype=dtype)).dtype == dtypes.bool
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_functions_return_index(self, dtype, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
|
||||
def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
|
||||
X_data = Tensor.ones(60000, 1, 28, 28, dtype=data_dtype)
|
||||
indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
|
||||
assert X_data[indices].dtype == X_data.dtype
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
|
||||
def test_gather_returns_same_dtype(self, data_dtype, indices_dtype):
|
||||
X_data = Tensor([[1, 0], [0, 1]], dtype=data_dtype)
|
||||
indices = Tensor([[0, 0], [1, 0]], dtype=indices_dtype)
|
||||
assert X_data.gather(0, indices).dtype == X_data.dtype
|
||||
assert X_data.gather(1, indices).dtype == X_data.dtype
|
||||
|
||||
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
|
||||
def test_attention_returns_same_dtype(self, data_dtype, default_float):
|
||||
dtypes.default_float = default_float
|
||||
query = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
|
||||
key = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
|
||||
value = Tensor.rand(32, 8, 128, 64, dtype=data_dtype)
|
||||
mask = (Tensor.rand(32, 8, 128, 128) < 0.5)
|
||||
assert query.scaled_dot_product_attention(key, value, is_causal=True).dtype == data_dtype
|
||||
assert query.scaled_dot_product_attention(key, value, is_causal=True, dropout_p=0.3).dtype == data_dtype
|
||||
assert query.scaled_dot_product_attention(key, value, is_causal=False).dtype == data_dtype
|
||||
assert query.scaled_dot_product_attention(key, value, attn_mask=mask).dtype == data_dtype
|
||||
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_self_promo_to_self(self, dtype):
|
||||
|
|
@ -398,237 +228,5 @@ class TestTypePromotion(unittest.TestCase):
|
|||
assert least_upper_dtype(dtypes.fp8e5m2, dtypes.int64) == dtypes.fp8e5m2
|
||||
assert least_upper_dtype(dtypes.fp8e5m2, dtypes.uint64) == dtypes.fp8e5m2
|
||||
|
||||
class TestAutoCastType(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
def tearDown(self):
|
||||
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
|
||||
|
||||
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
|
||||
def test_least_upper_float_input_is_float(self, input_dtype, default_float):
|
||||
dtypes.default_float = default_float
|
||||
self.assertEqual(least_upper_float(input_dtype), input_dtype)
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_least_upper_float_input_is_int(self, input_dtype, default_float):
|
||||
dtypes.default_float = default_float
|
||||
self.assertEqual(least_upper_float(input_dtype), default_float)
|
||||
|
||||
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
|
||||
def test_int_to_float_unary_func(self, dtype):
|
||||
for func in [
|
||||
lambda t: t.exp(),
|
||||
lambda t: t.exp2(),
|
||||
lambda t: t.log(),
|
||||
lambda t: t.log2(),
|
||||
lambda t: t.sqrt(),
|
||||
lambda t: t.rsqrt(),
|
||||
lambda t: t.sin(),
|
||||
lambda t: t.cos(),
|
||||
lambda t: t.tan(),
|
||||
lambda t: t.sigmoid(),
|
||||
]:
|
||||
a = [2, 3, 4]
|
||||
# float16 can have larger precision errors
|
||||
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, atol=1e-3)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_broadcast_scalar(self, dt):
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
|
||||
|
||||
@given(strat.sampled_from(dtype_floats))
|
||||
def test_int_div_int(self, default_float):
|
||||
dtypes.default_float = default_float
|
||||
self.assertEqual(Tensor([1]).div(Tensor([2])).dtype, default_float)
|
||||
|
||||
def test_sum(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
|
||||
def test_sum_dtype_arg(self):
|
||||
t = Tensor([40000, 40000], dtype=dtypes.float16)
|
||||
# default float16 sum returns in float16, overflowed in this case
|
||||
assert t.sum().dtype == dtypes.float16
|
||||
assert math.isinf(t.sum().numpy().item())
|
||||
# specifiying dtype and it's not downcasted
|
||||
assert t.sum(dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.sum(dtype=dtypes.float32).numpy(), 80000)
|
||||
|
||||
def test_prod_dtype_arg(self):
|
||||
t = Tensor([100, 200], dtype=dtypes.int32)
|
||||
assert t.prod().dtype == dtypes.int32
|
||||
np.testing.assert_allclose(t.prod().numpy(), 20000)
|
||||
assert t.prod(dtype=dtypes.float32).dtype == dtypes.float32
|
||||
np.testing.assert_allclose(t.prod(dtype=dtypes.float32).numpy(), 20000)
|
||||
|
||||
def test_mean(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int16)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int64)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint8)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
|
||||
|
||||
def test_cumsum(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
|
||||
assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum(0).dtype == dtypes.fp8e4m3
|
||||
assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum(0).dtype == dtypes.fp8e5m2
|
||||
assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
|
||||
assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
|
||||
assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
|
||||
assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_matmul(self, dt1, dt2, acc_dt):
|
||||
t1 = Tensor([0, 1], dtype=dt1)
|
||||
t2 = Tensor([0, 1], dtype=dt2)
|
||||
self.assertEqual(t1.matmul(t2).dtype, least_upper_dtype(t1.dtype, t2.dtype))
|
||||
# if dtype is specified, return in dtype
|
||||
self.assertEqual(t1.matmul(t2, dtype=acc_dt).dtype, acc_dt)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_linear(self, dt1, dt2, dt3, acc_dt):
|
||||
x = Tensor([0, 1], dtype=dt1)
|
||||
w = Tensor([0, 1], dtype=dt2)
|
||||
b = Tensor([0, 1], dtype=dt3)
|
||||
self.assertEqual(x.linear(w).dtype, least_upper_dtype(x.dtype, w.dtype))
|
||||
self.assertEqual(x.linear(w, b).dtype, least_upper_dtype(least_upper_dtype(x.dtype, w.dtype), b.dtype))
|
||||
# if dtype is specified, return in dtype
|
||||
self.assertEqual(x.linear(w, dtype=acc_dt).dtype, acc_dt)
|
||||
self.assertEqual(x.linear(w, b, dtype=acc_dt).dtype, acc_dt)
|
||||
|
||||
@staticmethod
|
||||
def check_where_alternate_input_other(input_, other, data_type):
|
||||
assert (Tensor([True, False]).where(input_, other)).dtype == data_type
|
||||
assert (Tensor([True, False]).where(other, input_)).dtype == data_type
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_where_no_scalar(self, dt1, dt2):
|
||||
self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2))
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_where_one_scalar(self, dt):
|
||||
t = Tensor(2, dtype=dt)
|
||||
self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float))
|
||||
self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int))
|
||||
self.check_where_alternate_input_other(t, True, dt)
|
||||
|
||||
def test_where_two_scalars(self):
|
||||
self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3.1, 3, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3.1, True, dtypes.default_float)
|
||||
self.check_where_alternate_input_other(3, 2, dtypes.default_int)
|
||||
self.check_where_alternate_input_other(3, True, dtypes.default_int)
|
||||
self.check_where_alternate_input_other(False, True, dtypes.bool)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes))
|
||||
def test_maximum(self, dt1, dt2):
|
||||
assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)
|
||||
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
def test_maximum_const(self, dt):
|
||||
assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
|
||||
assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
|
||||
assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
|
||||
|
||||
def test_div(self):
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.int16) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.float32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float32
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float16
|
||||
|
||||
def test_div_const(self):
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / 2).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.int32) / 2.0).dtype == dtypes.default_float
|
||||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
|
||||
assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
|
||||
|
||||
def test_gradient_dtype(self):
|
||||
old_default_float = dtypes.default_float
|
||||
|
||||
for default_dtype in dtypes.floats:
|
||||
if not is_dtype_supported(default_dtype): continue
|
||||
dtypes.default_float = default_dtype
|
||||
for dtype in dtypes.floats:
|
||||
if not is_dtype_supported(dtype): continue
|
||||
if DEBUG >= 2:
|
||||
print(f"testing {default_dtype=}, {dtype=}")
|
||||
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
|
||||
b = (a * 5).sum()
|
||||
b.backward() # if there is dtype mismatch, lazy should assert
|
||||
assert a.grad.dtype == a.dtype
|
||||
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
|
||||
|
||||
dtypes.default_float = old_default_float
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
|
||||
@slow
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_mean_half_precision_underflow(self):
|
||||
N = 10000
|
||||
x = 0.001
|
||||
t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
|
||||
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
|
||||
|
||||
@unittest.skip("this test only works with SPLIT_REDUCEOP=1")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_mean_half_precision_overflow(self):
|
||||
N = 256
|
||||
t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
|
||||
np.testing.assert_allclose(t.mean().numpy(), 60000)
|
||||
t.square().mean().backward()
|
||||
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Precision error")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_softmax_dtype(self):
|
||||
data = [1, 2, 3]
|
||||
t = Tensor(data, dtype=dtypes.half)
|
||||
tt = torch.tensor(data, dtype=torch.half)
|
||||
|
||||
out = t.softmax(0)
|
||||
self.assertEqual(out.dtype, dtypes.half)
|
||||
np.testing.assert_allclose(out.numpy(), tt.softmax(0).numpy(), rtol=1e-3)
|
||||
out = t.softmax(0, dtype=dtypes.float)
|
||||
self.assertEqual(out.dtype, dtypes.float)
|
||||
np.testing.assert_allclose(out.numpy(), tt.softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
|
||||
out = t.log_softmax(0)
|
||||
self.assertEqual(out.dtype, dtypes.half)
|
||||
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0).numpy(), rtol=1e-3)
|
||||
out = t.log_softmax(0, dtype=dtypes.float)
|
||||
self.assertEqual(out.dtype, dtypes.float)
|
||||
np.testing.assert_allclose(out.numpy(), tt.log_softmax(0, dtype=torch.float).numpy(), rtol=1e-3)
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Callable
|
||||
import unittest, math
|
||||
import torch
|
||||
import numpy as np
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
|
@ -63,71 +62,6 @@ class TestGradient(unittest.TestCase):
|
|||
def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y)
|
||||
def test_where(self): self._test_two_input_function(lambda x,y: (x<y).where(x,y), lambda x,y: torch.where(x<y,x,y))
|
||||
|
||||
class TestTensorGradient(unittest.TestCase):
|
||||
def test_example(self):
|
||||
x = Tensor.eye(3)
|
||||
y = Tensor([[2.0,0,-2.0]])
|
||||
z = y.matmul(x).sum()
|
||||
dx, dy = z.gradient(x, y)
|
||||
self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
|
||||
self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])
|
||||
|
||||
def test_raises(self):
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
w = Tensor.randn((3,))
|
||||
with self.assertRaises(RuntimeError): x.sum().gradient(w)
|
||||
|
||||
def test_with_custom_gradient(self):
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
z = (x * x).sum()
|
||||
dx = z.gradient(x, gradient=Tensor([3.0]))[0]
|
||||
self.assertListEqual(dx.tolist(), [6.0, 12.0, 18.0])
|
||||
|
||||
def test_broadcast_gradient(self):
|
||||
x = Tensor([[1.0], [2.0], [3.0]])
|
||||
y = Tensor([[10.0, 20.0, 30.0, 40.0]])
|
||||
z = (x + y).sum()
|
||||
dx, dy = z.gradient(x, y)
|
||||
self.assertListEqual(dx.tolist(), [[4.0], [4.0], [4.0]])
|
||||
self.assertListEqual(dy.tolist(), [[3.0, 3.0, 3.0, 3.0]])
|
||||
|
||||
def test_non_scalar_output(self):
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
z = x * x
|
||||
with self.assertRaises(AssertionError): z.gradient(x)
|
||||
dz = Tensor([1.0, 1.0, 1.0])
|
||||
dx = z.gradient(x, gradient=dz)[0]
|
||||
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
|
||||
|
||||
def test_cast_before_view(self):
|
||||
x = Tensor([1.0, 1, 1, 1])
|
||||
x_reshaped = x.reshape(2,2)
|
||||
x_casted = x_reshaped.cast(dtypes.float16)
|
||||
x_casted.mean().gradient(x_reshaped)
|
||||
|
||||
def test_non_float_tensor_raise(self):
|
||||
x = Tensor([1, 2, 3])
|
||||
with self.assertRaises(RuntimeError): x.sum().gradient(x)
|
||||
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
|
||||
|
||||
def test_copy_to_device_gradient(self):
|
||||
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
|
||||
t.to("CPU:1").square().sum().backward()
|
||||
self.assertEqual(t.grad.device, t.device)
|
||||
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
|
||||
|
||||
def test_multiple_backward(self):
|
||||
x = Tensor([3.], requires_grad=True)
|
||||
(x*2)[0].backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0])
|
||||
old_grad = x.grad
|
||||
(x*3)[0].backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0])
|
||||
self.assertIs(x.grad, old_grad)
|
||||
(x*x)[0].backward()
|
||||
np.testing.assert_allclose(x.grad.numpy(), [2.0+3.0+2*3.0])
|
||||
self.assertIs(x.grad, old_grad)
|
||||
|
||||
class TestRealizeMeansRealize(unittest.TestCase):
|
||||
def test_randn_realizes(self):
|
||||
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
|
||||
|
|
@ -148,18 +82,5 @@ class TestRealizeMeansRealize(unittest.TestCase):
|
|||
y = x * 2
|
||||
y.sum().gradient(x)[0].realize()
|
||||
|
||||
class TestViewGradient(unittest.TestCase):
|
||||
def test_expand(self):
|
||||
# this test shows that if Tensors collapse to the views and create a disconnected graph
|
||||
# there's no way to recover the proper gradient
|
||||
x = Tensor.randn(5,2)
|
||||
a = Tensor([3.], requires_grad=True)
|
||||
aex = a.expand(10)
|
||||
(aex.reshape(5,2) * x).sum().backward()
|
||||
np.testing.assert_allclose(aex.grad.numpy(), x.reshape(10).numpy())
|
||||
# NOTE: aex.grad is *not* a.grad.expand(10)!
|
||||
with self.assertRaises(AssertionError):
|
||||
np.testing.assert_allclose(aex.grad.numpy(), a.grad.expand(10).numpy())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -356,10 +356,6 @@ class TestPolyN(unittest.TestCase):
|
|||
np.testing.assert_allclose(polyN(3.0, [1.0, -2.0, 1.0]), 4.0)
|
||||
np.testing.assert_allclose(polyN(4.0, [1.0, -2.0, 1.0]), 9.0)
|
||||
|
||||
def test_tensor(self):
|
||||
from tinygrad.tensor import Tensor
|
||||
np.testing.assert_allclose(polyN(Tensor([1.0, 2.0, 3.0, 4.0]), [1.0, -2.0, 1.0]).numpy(), [0.0, 1.0, 4.0, 9.0])
|
||||
|
||||
def test_uop(self):
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue