rangeify: no copies for write+read of same slice (#16585)

* failing test

* cleaner failing tests

* assign and read of same slice shouldn't create copies

* err in the changes

* shrink with no overlapping regions in dest is fine
This commit is contained in:
qazal 2026-06-13 01:19:47 +08:00 committed by GitHub
commit b2e95b2db3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 1 deletions

View file

@ -44,6 +44,32 @@ class TestAssign(unittest.TestCase):
c.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
def test_assign_slice(self):
  # write + read of the very same slice: bump the last two elements of a realized tensor
  base = Tensor([1,2,3,4]).realize()
  tail = base[2:4]
  tail.assign(tail+1)
  # counters are cleared only after the assign is set up, so kernel_count below
  # reflects whatever runs from this point on (presumably the tolist() realization)
  GlobalCounters.reset()
  self.assertListEqual(base.tolist(), [1,2,4,5])
  # assign and read of the same slice shouldn't create copies: one kernel is expected
  self.assertEqual(GlobalCounters.kernel_count, 1)
def test_assign_slice_alt(self):
  # destination [1:3] and source [2:4] are different slices of the same tensor that share index 2
  base = Tensor([1,2,3,4]).realize()
  dst = base[1:3]
  src = base[2:4]
  dst.assign(src+1)
  # clear the counters so only work done after this point is counted
  GlobalCounters.reset()
  self.assertListEqual(base.tolist(), [1,4,5,4])
  # two kernels are expected here (presumably a copy of the source plus the store -- confirm against the scheduler)
  self.assertEqual(GlobalCounters.kernel_count, 2)
def test_assign_flip(self):
  # assign into a reversed view a value that reads both that view and the tensor itself
  data = np.arange(16, dtype=np.float32)
  base = Tensor(data, device="CPU").contiguous().realize()
  GlobalCounters.reset()
  flipped = base[::-1]
  flipped.assign(flipped + base)
  # every element should end up as the sum of itself and its mirrored counterpart
  expected = data + data[::-1]
  np.testing.assert_allclose(base.numpy(), expected)
  # two kernels are expected for the flipped write
  self.assertEqual(GlobalCounters.kernel_count, 2)
def test_assign_add(self):
for T in (1, 2, 10):#, 100): # this crashes in CI, not sure why
x = Tensor([0]).realize()

View file

@ -95,7 +95,7 @@ def fix_store_hazard(target:UOp, src:UOp):
reaches_base: dict[UOp, bool] = {}
for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS):
reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src)
if reaches_base[s] and s.op in unsafe: return target.store(src.contiguous())
if reaches_base[s] and s.op in unsafe and not (s is target and s.op is Ops.SHRINK): return target.store(src.contiguous())
def split_reduceop(reduce:UOp, x:UOp):
if prod(reduce.shape) == 0: return None