mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
everything has dtype.long now (#14661)
* everything has dtype.long now * int64/uint64 are everywhere now * that doesn't work
This commit is contained in:
parent
cdb78954cb
commit
8dc46dde07
11 changed files with 12 additions and 30 deletions
|
|
@ -1,10 +1,6 @@
|
|||
import unittest
|
||||
from extra.models import resnet
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
# pretrained weights contain num_batches_tracked as int64
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), "need int64 support")
|
||||
class TestResnet(unittest.TestCase):
|
||||
def test_model_load(self):
|
||||
model = resnet.ResNet18()
|
||||
|
|
|
|||
|
|
@ -390,7 +390,6 @@ class TestSchedule(unittest.TestCase):
|
|||
out = bn(c1(img)).relu()
|
||||
check_schedule(out, 4, [c1.weight, c1.bias])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 27), (nn.optim.SGD, 7)]:
|
||||
|
|
@ -796,7 +795,6 @@ class TestSchedule(unittest.TestCase):
|
|||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 7)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
|
||||
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
|
||||
with Tensor.train():
|
||||
img = Tensor.empty(2,3,4,4)
|
||||
|
|
|
|||
|
|
@ -34,9 +34,8 @@ class TestMovedConstFolding(unittest.TestCase):
|
|||
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
|
||||
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
|
||||
# folded
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
|
||||
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
|
||||
|
||||
class TestReduceOpsConstFolding(unittest.TestCase):
|
||||
def test_const_sum(self):
|
||||
|
|
|
|||
|
|
@ -227,7 +227,6 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
|
||||
def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
|
||||
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
|
||||
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
|
||||
|
||||
|
|
@ -265,7 +264,6 @@ class TestDTypeALU(unittest.TestCase):
|
|||
@given(ht.int32, strat.sampled_from(integer_unary_operations))
|
||||
def test_int32_unary(self, a, op): universal_test_unary(a, dtypes.int32, op)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
|
||||
@given(ht.int64, strat.sampled_from(integer_unary_operations))
|
||||
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
|
||||
|
||||
|
|
|
|||
|
|
@ -596,7 +596,7 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op(None, lambda x: x//2, forward_only=True, vals=[[3, 4, 5]])
|
||||
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
|
||||
vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
|
||||
if is_dtype_supported(dtypes.uint64) and not COMPILE_ONLY:
|
||||
if not COMPILE_ONLY:
|
||||
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
|
||||
np.testing.assert_equal(x.numpy(), 2**64 - 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class TestRandomness(unittest.TestCase):
|
|||
self.assertTrue(r1.uop.is_realized, "tensor should be realized after .realize()")
|
||||
self.assertTrue(r2.uop.is_realized, "tensor should be realized after .realize()")
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16) and is_dtype_supported(dtypes.ulong), "need float16 and ulong support")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
def test_rand_float16(self):
|
||||
N = 128
|
||||
x = Tensor.rand((2, N, N), dtype=dtypes.float16)
|
||||
|
|
|
|||
|
|
@ -762,7 +762,7 @@ class TestSchedule(unittest.TestCase):
|
|||
def test_conv2d(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
|
||||
def test_conv2d_fused(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_conv2d_half(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4, dtype=dtypes.half)
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import unittest
|
|||
from tinygrad import Tensor, Variable, GlobalCounters
|
||||
from tinygrad.uop.ops import sym_infer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from examples.gpt2 import Attention
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -273,7 +272,6 @@ class TestSymbolicOps(unittest.TestCase):
|
|||
symbolic = symbolic_result[:].numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), "no uint64")
|
||||
def test_bitcast_up(self):
|
||||
a = Tensor.rand(10, 4)
|
||||
for i in range(1, 5):
|
||||
|
|
|
|||
|
|
@ -82,8 +82,7 @@ class TestTypeSpec(unittest.TestCase):
|
|||
|
||||
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
|
||||
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
|
||||
|
||||
|
|
@ -92,23 +91,20 @@ class TestTypeSpec(unittest.TestCase):
|
|||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
|
||||
|
||||
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
|
||||
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
|
|
@ -130,8 +126,7 @@ class TestTypeSpec(unittest.TestCase):
|
|||
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
|
||||
if is_dtype_supported(dtypes.int64):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import unittest, random, warnings
|
|||
import numpy as np
|
||||
|
||||
from tinygrad import Tensor, dtypes, Device, TinyJit
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import all_same, prod
|
||||
from test.helpers import slow
|
||||
|
||||
|
|
@ -525,7 +524,6 @@ class TestIndexing(unittest.TestCase):
|
|||
a = src[0].mul(src[1])
|
||||
self.assertEqual(a[0,1].item(), 2)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), "need dtypes.int64")
|
||||
def test_getitem_scalars(self):
|
||||
zero = Tensor(0, dtype=dtypes.int64)
|
||||
one = Tensor(1, dtype=dtypes.int64)
|
||||
|
|
@ -649,7 +647,6 @@ class TestIndexing(unittest.TestCase):
|
|||
i, j = indices
|
||||
numpy_testing_assert_equal_helper(x[i:j], x[0:1])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.int64), "tensor indexing uses int64 internally")
|
||||
def test_ellipsis_tensor(self):
|
||||
x = Tensor.arange(0, 9).reshape(3, 3)
|
||||
idx = Tensor([0, 2])
|
||||
|
|
|
|||
|
|
@ -341,6 +341,7 @@ class Compiled:
|
|||
# override this in your device implementation
|
||||
|
||||
# TODO: move this to each Device
|
||||
# this only tracks if the dtype is natively supported, it may be supported in the frontend using decomps
|
||||
def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
|
||||
if dtype == dtypes.index: return False
|
||||
if device is None: device = Device.DEFAULT
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue