mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
UOp cast and bitcast takes DTypeLike [PR] (#16582)
* UOp cast and bitcast takes DTypeLike [PR] match Tensor * fix type
This commit is contained in:
parent
b97e3e01e3
commit
a2cec397f3
2 changed files with 15 additions and 2 deletions
|
|
@ -209,8 +209,18 @@ class TestTensorUOpAllclose(unittest.TestCase):
|
|||
a, b = _t(4).float(), _t(4).float()
|
||||
self.assertIs(a.allclose(b).uop, a.uop.allclose(b.uop))
|
||||
|
||||
class TestTensorUOpCast(unittest.TestCase):
|
||||
def test_cast_str_dtype(self):
|
||||
t = _t(4)
|
||||
self.assertIs(t.cast("float32").uop, t.uop.cast("float32"))
|
||||
self.assertIs(t.uop.cast("float32").dtype, dtypes.float32)
|
||||
|
||||
class TestTensorUOpBitcast(unittest.TestCase):
|
||||
def test_bitcast_same_dtype(self): _check(self, _t(4).float(), lambda x: x.bitcast(dtypes.float32))
|
||||
def test_bitcast_str_dtype(self):
|
||||
t = _t(4)
|
||||
self.assertIs(t.bitcast("uint32").uop, t.uop.bitcast("uint32"))
|
||||
self.assertIs(t.uop.bitcast("uint32").dtype, dtypes.uint32)
|
||||
|
||||
class TestTensorUOpRand(unittest.TestCase):
|
||||
def test_random_bits(self):
|
||||
|
|
|
|||
|
|
@ -504,12 +504,15 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
assert self.dtype.vcount == 1
|
||||
if count == 1: return self
|
||||
return UOp(Ops.STACK, self.dtype.vec(count), (self,)*count)
|
||||
def cast(self, dtype:DType):
|
||||
def cast(self, dtype:DTypeLike):
|
||||
dtype = to_dtype(dtype)
|
||||
# TODO: we shouldn't have to check for dtype.count == 1 here, but CAST is misused in AMD LLVM
|
||||
if dtype.count == 1 and dtype.count != self.dtype.count: dtype = dtype.vec(self.dtype.count)
|
||||
if self.dtype == dtype: return self
|
||||
return UOp(Ops.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
|
||||
def bitcast(self, dtype:DTypeLike):
|
||||
dtype = to_dtype(dtype)
|
||||
return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:tuple[int, ...]|int):
|
||||
if isinstance(i, tuple) and len(i) == 1: return self.gep(i[0])
|
||||
if isinstance(i, int):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue