UOp.clone [pr] (#16295)

UOp.clone generates the store-then-after structure (`ret.after(ret.store(src))`).
This commit is contained in:
chenyu 2026-05-20 17:47:49 -04:00 committed by GitHub
commit beea4633fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 19 additions and 4 deletions

View file

@ -2,7 +2,10 @@ import math, unittest
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite
# NOTE(review): this one-line binding looks like the pre-change side of the diff; the name is rebound
# by the multi-line definition right below, which adds a BUFFER rule — confirm against the raw patch.
_strip_unique_pm = PatternMatcher([(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),])
# Rewrite rules applied by `_strip_unique` so that graphs differing only in their UNIQUE sources can be compared.
_strip_unique_pm = PatternMatcher([
# CONST whose sources are (UNIQUE, DEVICE): keep only the DEVICE source.
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
# BUFFER whose sources are (UNIQUE, DEVICE): replace the UNIQUE source with the fixed `UOp.unique(0)`.
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(UOp.unique(0), d))),
])
def _strip_unique(u: UOp) -> UOp:
  """Run `graph_rewrite` over `u` with the `_strip_unique_pm` rules and return the rewritten UOp."""
  return graph_rewrite(u, _strip_unique_pm)
def _t(*shape):
@ -49,6 +52,14 @@ class TestTensorUOpBinop(unittest.TestCase):
def test_mod_bool(self):
  # `%` with a Python bool on a tensor cast to dtypes.bool, handed to `_check`
  bool_input = _t(4).cast(dtypes.bool)
  _check(self, bool_input, lambda x: x % True)
def test_fmod_bool(self):
  # `fmod` with a Python bool on a tensor cast to dtypes.bool, handed to `_check`
  bool_input = _t(4).cast(dtypes.bool)
  _check(self, bool_input, lambda x: x.fmod(True))
class TestTensorUOpClone(unittest.TestCase):
  """Compare `Tensor.clone` with `UOp.clone` after both results go through `_strip_unique`."""
  def test_clone(self):
    # the two rewritten graphs must be the very same object
    base = _t(3, 4).float()
    from_tensor = _strip_unique(base.clone().uop)
    from_uop = _strip_unique(base.uop.clone())
    self.assertIs(from_tensor, from_uop)
  def test_clone_deviceless_const(self):
    # same comparison, starting from a const UOp built without an explicit device argument
    const_uop = UOp.const(dtypes.float, 2.0)
    from_tensor = _strip_unique(Tensor(const_uop).clone().uop)
    from_uop = _strip_unique(const_uop.clone())
    self.assertIs(from_tensor, from_uop)
class TestTensorUOpGetitem(unittest.TestCase):
# ---- pure slice patterns ----
def test_slice_full(self):
  # index with a bare `slice(None)`, handed to `_check`
  full_slice = lambda x: x[slice(None)]
  _check(self, _t(4), full_slice)

View file

@ -357,10 +357,9 @@ class Tensor(OpMixin):
Creates a clone of this tensor allocating a separate buffer for the data.
If `device` is specified, the clone is placed on that device.
"""
device = device or self.device
ret = self.empty_like(device=device)
ret = Tensor(self.uop.clone(device=device))
if self.grad is not None: ret.grad = self.grad.clone(device=device)
return ret.assign(self.to(device))
return ret
def to(self, device:str|tuple[str, ...]|None) -> Tensor:
"""

View file

@ -703,6 +703,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
device = canonicalize_device(self.device if device is None else device)
axis = self.axis if isinstance(device, tuple) else None
return UOp.empty(self.shard_shape if axis is not None else self.shape, self.dtype if dtype is None else dtype, device, axis)
def clone(self, device=None) -> UOp:
  """
  Build a clone of this UOp: an `empty_like` destination with this UOp stored into it.

  `device` falls back to `self.device` when it is falsy. The value stored is `self` when `self.device`
  is None or already equals the target, otherwise `self.copy_to_device(target)`.
  Returns `dest.after(dest.store(src))`.
  """
  target = device or self.device
  dest = self.empty_like(device=target)
  # NOTE(review): the comparison below uses the raw `target` value — confirm device-name aliases are handled upstream
  if self.device is None or self.device == target: src = self
  else: src = self.copy_to_device(target)
  return dest.after(dest.store(src))
@recursive_property
def device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg