mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
a19fa2908f
commit
beea4633fc
3 changed files with 19 additions and 4 deletions
|
|
@ -2,7 +2,10 @@ import math, unittest
|
|||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite
|
||||
|
||||
_strip_unique_pm = PatternMatcher([(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),])
|
||||
_strip_unique_pm = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(UOp.unique(0), d))),
|
||||
])
|
||||
def _strip_unique(u: UOp) -> UOp: return graph_rewrite(u, _strip_unique_pm)
|
||||
|
||||
def _t(*shape):
|
||||
|
|
@ -49,6 +52,14 @@ class TestTensorUOpBinop(unittest.TestCase):
|
|||
def test_mod_bool(self): _check(self, _t(4).cast(dtypes.bool), lambda x: x % True)
|
||||
def test_fmod_bool(self): _check(self, _t(4).cast(dtypes.bool), lambda x: x.fmod(True))
|
||||
|
||||
class TestTensorUOpClone(unittest.TestCase):
|
||||
def test_clone(self):
|
||||
t = _t(3, 4).float()
|
||||
self.assertIs(_strip_unique(t.clone().uop), _strip_unique(t.uop.clone()))
|
||||
def test_clone_deviceless_const(self):
|
||||
u = UOp.const(dtypes.float, 2.0)
|
||||
self.assertIs(_strip_unique(Tensor(u).clone().uop), _strip_unique(u.clone()))
|
||||
|
||||
class TestTensorUOpGetitem(unittest.TestCase):
|
||||
# ---- pure slice patterns ----
|
||||
def test_slice_full(self): _check(self, _t(4), lambda x: x[slice(None)])
|
||||
|
|
|
|||
|
|
@ -357,10 +357,9 @@ class Tensor(OpMixin):
|
|||
Creates a clone of this tensor allocating a separate buffer for the data.
|
||||
If `device` is specified, the clone is placed on that device.
|
||||
"""
|
||||
device = device or self.device
|
||||
ret = self.empty_like(device=device)
|
||||
ret = Tensor(self.uop.clone(device=device))
|
||||
if self.grad is not None: ret.grad = self.grad.clone(device=device)
|
||||
return ret.assign(self.to(device))
|
||||
return ret
|
||||
|
||||
def to(self, device:str|tuple[str, ...]|None) -> Tensor:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -703,6 +703,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
device = canonicalize_device(self.device if device is None else device)
|
||||
axis = self.axis if isinstance(device, tuple) else None
|
||||
return UOp.empty(self.shard_shape if axis is not None else self.shape, self.dtype if dtype is None else dtype, device, axis)
|
||||
def clone(self, device=None) -> UOp:
|
||||
device = device or self.device
|
||||
ret = self.empty_like(device=device)
|
||||
src = self if self.device is None or self.device == device else self.copy_to_device(device)
|
||||
return ret.after(ret.store(src))
|
||||
@recursive_property
|
||||
def device(self) -> str|tuple[str, ...]|None:
|
||||
if self.op is Ops.DEVICE: return self.arg
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue