remove the "no copy" line from copy_to_device (#8702)

* delete the no copy one

* add tests
This commit is contained in:
qazal 2025-01-21 10:09:33 -05:00 committed by GitHub
commit d6bf1feaab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 2 deletions

View file

@ -2258,6 +2258,17 @@ class TestCopyFolding(unittest.TestCase):
add = schedule_graph_rewrite(add)
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
def test_copy_to_same_device(self):
a = Tensor.empty(4).lazydata
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)
b = schedule_graph_rewrite(b)
self.assertIs(b, a)
def test_clone(self):
a = Tensor.empty(4).lazydata
check_schedule(a.clone(), 1, filter_sink=False)
class TestTensorUOpSpec(unittest.TestCase):
def test_const_must_be_unmasked(self):
a = Tensor.ones((4, 4)).pad((2, 2))

View file

@ -427,8 +427,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# otherwise it's just a VIEW(BUFFER)
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
def copy_to_device(self, device:str, clone:bool=False) -> UOp:
# no COPY
if self.device == device and not clone: return self
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
# COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)