mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
rangeify: no copies for write+read of same slice (#16585)
* failing test
* cleaner failing tests
* assign and read of same slice shouldn't create copies
* err in the changes
* shrink with no overlapping regions in dest is fine
This commit is contained in:
parent
833cb37574
commit
b2e95b2db3
2 changed files with 27 additions and 1 deletion
|
|
@@ -44,6 +44,32 @@ class TestAssign(unittest.TestCase):
|
|||
c.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
|
||||
def test_assign_slice(self):
  # Write to a slice using a read of that very same slice as the source.
  base = Tensor([1,2,3,4]).realize()
  tail = base[2:4]
  tail.assign(tail+1)
  # Counters are reset here, so kernel_count below only covers what runs from tolist() onward.
  GlobalCounters.reset()
  self.assertListEqual(base.tolist(), [1,2,4,5])
  # Exactly one kernel is expected for the in-place slice update.
  self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
|
||||
def test_assign_slice_alt(self):
  # Destination [1:3] and source [2:4] are different, overlapping slices of the same tensor.
  base = Tensor([1,2,3,4]).realize()
  dst = base[1:3]
  src = base[2:4]
  dst.assign(src+1)
  # Counters are reset here, so kernel_count below only covers what runs from tolist() onward.
  GlobalCounters.reset()
  # src holds [3,4]; src+1 == [4,5] lands in positions 1 and 2.
  self.assertListEqual(base.tolist(), [1,4,5,4])
  # Two kernels are expected here (presumably one is the copy guarding the overlap -- TODO confirm).
  self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
|
||||
def test_assign_flip(self):
  # Assign into a reversed view of a tensor, reading both the reversed view and the tensor itself.
  base_np = np.arange(16, dtype=np.float32)
  base = Tensor(base_np, device="CPU").contiguous().realize()
  # Counters are reset before the assign is built; kernel_count below covers everything after this point.
  GlobalCounters.reset()
  flipped = base[::-1]
  flipped.assign(flipped + base)
  # Each element ends up as itself plus its mirror element, which is symmetric under the flip.
  expected = base_np + base_np[::-1]
  np.testing.assert_allclose(base.numpy(), expected)
  self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
|
||||
def test_assign_add(self):
|
||||
for T in (1, 2, 10):#, 100): # this crashes in CI, not sure why
|
||||
x = Tensor([0]).realize()
|
||||
|
|
|
|||
|
|
@@ -95,7 +95,7 @@ def fix_store_hazard(target:UOp, src:UOp):
|
|||
reaches_base: dict[UOp, bool] = {}
|
||||
for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS):
|
||||
reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src)
|
||||
if reaches_base[s] and s.op in unsafe: return target.store(src.contiguous())
|
||||
if reaches_base[s] and s.op in unsafe and not (s is target and s.op is Ops.SHRINK): return target.store(src.contiguous())
|
||||
|
||||
def split_reduceop(reduce:UOp, x:UOp):
|
||||
if prod(reduce.shape) == 0: return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue