mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove the "no copy" line from copy_to_device (#8702)
* delete the no copy one * add tests
This commit is contained in:
parent
3628f89929
commit
d6bf1feaab
2 changed files with 11 additions and 2 deletions
|
|
@ -2258,6 +2258,17 @@ class TestCopyFolding(unittest.TestCase):
|
|||
add = schedule_graph_rewrite(add)
|
||||
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
|
||||
|
||||
def test_copy_to_same_device(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
b = a.copy_to_device(a.device)
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
b = schedule_graph_rewrite(b)
|
||||
self.assertIs(b, a)
|
||||
|
||||
def test_clone(self):
|
||||
a = Tensor.empty(4).lazydata
|
||||
check_schedule(a.clone(), 1, filter_sink=False)
|
||||
|
||||
class TestTensorUOpSpec(unittest.TestCase):
|
||||
def test_const_must_be_unmasked(self):
|
||||
a = Tensor.ones((4, 4)).pad((2, 2))
|
||||
|
|
|
|||
|
|
@ -427,8 +427,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
# otherwise it's just a VIEW(BUFFER)
|
||||
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
|
||||
def copy_to_device(self, device:str, clone:bool=False) -> UOp:
|
||||
# no COPY
|
||||
if self.device == device and not clone: return self
|
||||
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
|
||||
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
|
||||
# COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue