everything has dtype.long now (#14661)

* everything has dtype.long now

* int64/uint64 are everywhere now

* that doesn't work
This commit is contained in:
George Hotz 2026-02-10 15:08:50 +08:00 committed by GitHub
commit 8dc46dde07
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 12 additions and 30 deletions

View file

@ -1,10 +1,6 @@
import unittest
from extra.models import resnet
from tinygrad import dtypes
from tinygrad.device import is_dtype_supported
# pretrained weights contain num_batches_tracked as int64
@unittest.skipUnless(is_dtype_supported(dtypes.int64), "need int64 support")
class TestResnet(unittest.TestCase):
def test_model_load(self):
model = resnet.ResNet18()

View file

@ -390,7 +390,6 @@ class TestSchedule(unittest.TestCase):
out = bn(c1(img)).relu()
check_schedule(out, 4, [c1.weight, c1.bias])
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 27), (nn.optim.SGD, 7)]:
@ -796,7 +795,6 @@ class TestSchedule(unittest.TestCase):
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 7)
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train():
img = Tensor.empty(2,3,4,4)

View file

@ -34,9 +34,8 @@ class TestMovedConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# folded
if is_dtype_supported(dtypes.int64):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
class TestReduceOpsConstFolding(unittest.TestCase):
def test_const_sum(self):

View file

@ -227,7 +227,6 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
@given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)
@ -265,7 +264,6 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(integer_unary_operations))
def test_int32_unary(self, a, op): universal_test_unary(a, dtypes.int32, op)
@unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
@given(ht.int64, strat.sampled_from(integer_unary_operations))
def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)

View file

@ -596,7 +596,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x//2, forward_only=True, vals=[[3, 4, 5]])
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
if is_dtype_supported(dtypes.uint64) and not COMPILE_ONLY:
if not COMPILE_ONLY:
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
np.testing.assert_equal(x.numpy(), 2**64 - 1)

View file

@ -83,7 +83,7 @@ class TestRandomness(unittest.TestCase):
self.assertTrue(r1.uop.is_realized, "tensor should be realized after .realize()")
self.assertTrue(r2.uop.is_realized, "tensor should be realized after .realize()")
@unittest.skipUnless(is_dtype_supported(dtypes.float16) and is_dtype_supported(dtypes.ulong), "need float16 and ulong support")
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
def test_rand_float16(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.float16)

View file

@ -762,7 +762,7 @@ class TestSchedule(unittest.TestCase):
def test_conv2d(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
def test_conv2d_fused(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4)
@unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_conv2d_half(self): _test_conv2d(5 if SPLIT_REDUCEOP else 4, dtype=dtypes.half)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Causes other tests to fail")

View file

@ -2,7 +2,6 @@ import unittest
from tinygrad import Tensor, Variable, GlobalCounters
from tinygrad.uop.ops import sym_infer
from tinygrad.dtype import dtypes
from tinygrad.device import is_dtype_supported
from examples.gpt2 import Attention
import numpy as np
@ -273,7 +272,6 @@ class TestSymbolicOps(unittest.TestCase):
symbolic = symbolic_result[:].numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
@unittest.skipUnless(is_dtype_supported(dtypes.uint64), "no uint64")
def test_bitcast_up(self):
a = Tensor.rand(10, 4)
for i in range(1, 5):

View file

@ -82,8 +82,7 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
@ -92,23 +91,20 @@ class TestTypeSpec(unittest.TestCase):
dtypes.default_int, dtypes.default_float = default_int, default_float
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
@ -130,8 +126,7 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.arange(5.0), dtypes.default_float, np.arange(5))
if is_dtype_supported(dtypes.int16):
_assert_eq(Tensor.arange(5, dtype=dtypes.int16), dtypes.int16, np.arange(5))
if is_dtype_supported(dtypes.int64):
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
_assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
if is_dtype_supported(dtypes.float16):
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)

View file

@ -4,7 +4,6 @@ import unittest, random, warnings
import numpy as np
from tinygrad import Tensor, dtypes, Device, TinyJit
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import all_same, prod
from test.helpers import slow
@ -525,7 +524,6 @@ class TestIndexing(unittest.TestCase):
a = src[0].mul(src[1])
self.assertEqual(a[0,1].item(), 2)
@unittest.skipUnless(is_dtype_supported(dtypes.int64), "need dtypes.int64")
def test_getitem_scalars(self):
zero = Tensor(0, dtype=dtypes.int64)
one = Tensor(1, dtype=dtypes.int64)
@ -649,7 +647,6 @@ class TestIndexing(unittest.TestCase):
i, j = indices
numpy_testing_assert_equal_helper(x[i:j], x[0:1])
@unittest.skipUnless(is_dtype_supported(dtypes.int64), "tensor indexing uses int64 internally")
def test_ellipsis_tensor(self):
x = Tensor.arange(0, 9).reshape(3, 3)
idx = Tensor([0, 2])

View file

@ -341,6 +341,7 @@ class Compiled:
# override this in your device implementation
# TODO: move this to each Device
# this only tracks if the dtype is natively supported, it may be supported in the frontend using decomps
def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
if dtype == dtypes.index: return False
if device is None: device = Device.DEFAULT