contiguous to mixin and cleanups [PR] (#16711)

This commit is contained in:
chenyu 2026-06-22 20:18:18 -04:00 committed by GitHub
commit 15988b5941
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 10 additions and 13 deletions

View file

@ -45,7 +45,11 @@ class ElementwiseMixin(CreationMixin):
"""
return self.cast(dtypes.bool).ne(True)
def contiguous(self, *args, **kwargs) -> Self: raise NotImplementedError
def contiguous(self, **kwargs) -> Self:
"""
Returns a contiguous tensor.
"""
return self._wrap_uop(self._uop.contiguous(**kwargs))
def contiguous_backward(self) -> Self:
"""

View file

@ -130,9 +130,9 @@ class Tensor(RandMixin, metaclass=TensorMeta):
@suppress_finalizing
def __del__(self): all_tensors.pop(weakref.ref(self), None)
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, extra_args=(), **kwargs) -> Tensor:
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, **kwargs) -> Tensor:
srcs = (self,)+x
new_uop: UOp = fxn(*[t.uop for t in srcs], *extra_args, **kwargs)
new_uop: UOp = fxn(*[t.uop for t in srcs], **kwargs)
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
# directly create the Tensor
ret = Tensor.__new__(Tensor)
@ -264,6 +264,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
x = self.cast(self.dtype.base).contiguous()
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
def _data(self) -> memoryview: return self._buffer().as_memoryview()
def data(self) -> memoryview:
@ -279,7 +280,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
return self._data().cast(self.dtype.base.fmt, self.shape)
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
def tolist(self) -> PyConst|list[Any]:
@ -529,7 +530,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
# ***** movement ops *****
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, op=op, arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
@ -718,14 +719,6 @@ class Tensor(RandMixin, metaclass=TensorMeta):
if IMAGE: return self.image_dot(w, dtype)
return super().dot(w, dtype)
# ***** unary ops *****
def contiguous(self, *args, **kwargs) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous, extra_args=args, **kwargs)
# ***** broadcasted elementwise ops *****
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor: