UOp cast and bitcast takes DTypeLike [PR] (#16582)

* UOp cast and bitcast takes DTypeLike [PR]

match Tensor

* fix type
This commit is contained in:
chenyu 2026-06-11 22:38:54 -04:00 committed by GitHub
commit a2cec397f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 2 deletions

View file

@ -209,8 +209,18 @@ class TestTensorUOpAllclose(unittest.TestCase):
a, b = _t(4).float(), _t(4).float()
self.assertIs(a.allclose(b).uop, a.uop.allclose(b.uop))
class TestTensorUOpCast(unittest.TestCase):
def test_cast_str_dtype(self):
t = _t(4)
self.assertIs(t.cast("float32").uop, t.uop.cast("float32"))
self.assertIs(t.uop.cast("float32").dtype, dtypes.float32)
class TestTensorUOpBitcast(unittest.TestCase):
def test_bitcast_same_dtype(self): _check(self, _t(4).float(), lambda x: x.bitcast(dtypes.float32))
def test_bitcast_str_dtype(self):
t = _t(4)
self.assertIs(t.bitcast("uint32").uop, t.uop.bitcast("uint32"))
self.assertIs(t.uop.bitcast("uint32").dtype, dtypes.uint32)
class TestTensorUOpRand(unittest.TestCase):
def test_random_bits(self):

View file

@ -504,12 +504,15 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
assert self.dtype.vcount == 1
if count == 1: return self
return UOp(Ops.STACK, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType):
def cast(self, dtype:DTypeLike):
dtype = to_dtype(dtype)
# TODO: we shouldn't have to check for dtype.count == 1 here, but CAST is misused in AMD LLVM
if dtype.count == 1 and dtype.count != self.dtype.count: dtype = dtype.vec(self.dtype.count)
if self.dtype == dtype: return self
return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
def bitcast(self, dtype:DTypeLike):
dtype = to_dtype(dtype)
return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
def gep(self, i:tuple[int, ...]|int):
if isinstance(i, tuple) and len(i) == 1: return self.gep(i[0])
if isinstance(i, int):