mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
contiguous to mixin and cleanups [PR] (#16711)
This commit is contained in:
parent
cbfcf36e44
commit
15988b5941
2 changed files with 10 additions and 13 deletions
|
|
@ -45,7 +45,11 @@ class ElementwiseMixin(CreationMixin):
|
|||
"""
|
||||
return self.cast(dtypes.bool).ne(True)
|
||||
|
||||
def contiguous(self, *args, **kwargs) -> Self: raise NotImplementedError
|
||||
def contiguous(self, **kwargs) -> Self:
|
||||
"""
|
||||
Returns a contiguous tensor.
|
||||
"""
|
||||
return self._wrap_uop(self._uop.contiguous(**kwargs))
|
||||
|
||||
def contiguous_backward(self) -> Self:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -130,9 +130,9 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
|||
@suppress_finalizing
|
||||
def __del__(self): all_tensors.pop(weakref.ref(self), None)
|
||||
|
||||
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, extra_args=(), **kwargs) -> Tensor:
|
||||
def _apply_uop(self, fxn:Callable[..., UOp], *x:Tensor, **kwargs) -> Tensor:
|
||||
srcs = (self,)+x
|
||||
new_uop: UOp = fxn(*[t.uop for t in srcs], *extra_args, **kwargs)
|
||||
new_uop: UOp = fxn(*[t.uop for t in srcs], **kwargs)
|
||||
if TRACEMETA >= 1 and (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
|
||||
# directly create the Tensor
|
||||
ret = Tensor.__new__(Tensor)
|
||||
|
|
@ -264,6 +264,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
|||
x = self.cast(self.dtype.base).contiguous()
|
||||
if self.uop.device is None or isinstance(self.device, tuple): x = x.clone("CPU")
|
||||
return cast(Buffer, x.realize().uop.buffer).ensure_allocated()
|
||||
|
||||
def _data(self) -> memoryview: return self._buffer().as_memoryview()
|
||||
|
||||
def data(self) -> memoryview:
|
||||
|
|
@ -279,7 +280,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
|||
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
|
||||
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
|
||||
assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
|
||||
return self._buffer().as_memoryview().cast(self.dtype.base.fmt, self.shape)
|
||||
return self._data().cast(self.dtype.base.fmt, self.shape)
|
||||
|
||||
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
|
||||
def tolist(self) -> PyConst|list[Any]:
|
||||
|
|
@ -529,7 +530,7 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
|||
|
||||
# ***** movement ops *****
|
||||
|
||||
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, extra_args=(op,), arg=arg)
|
||||
def _mop(self, op:Ops, arg) -> Tensor: return self._apply_uop(UOp._mop, op=op, arg=arg)
|
||||
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor: return self._apply_uop(UOp._rop, op=op, axis=axis)
|
||||
|
||||
def __setitem__(self, indices, v:Tensor|PyConst|list|tuple) -> None:
|
||||
|
|
@ -718,14 +719,6 @@ class Tensor(RandMixin, metaclass=TensorMeta):
|
|||
if IMAGE: return self.image_dot(w, dtype)
|
||||
return super().dot(w, dtype)
|
||||
|
||||
# ***** unary ops *****
|
||||
|
||||
def contiguous(self, *args, **kwargs) -> Tensor:
|
||||
"""
|
||||
Returns a contiguous tensor.
|
||||
"""
|
||||
return self._apply_uop(UOp.contiguous, extra_args=args, **kwargs)
|
||||
|
||||
# ***** broadcasted elementwise ops *****
|
||||
|
||||
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue