mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
feat: assert on bufferview math (#12772)
This commit is contained in:
parent
fcdf4ab37e
commit
82f10cfe2e
2 changed files with 7 additions and 7 deletions
|
|
@ -433,17 +433,15 @@ class TestDiskTensorMovement(unittest.TestCase):
|
|||
t = Tensor(self.fn)
|
||||
self.assertListEqual(t[16:18].tolist(), [16,17])
|
||||
|
||||
# TODO: fix this! at least assert on it
|
||||
@unittest.expectedFailure
|
||||
def test_slice_read_cat(self):
|
||||
t = Tensor(self.fn)
|
||||
self.assertListEqual(Tensor.cat(t[16:18], t[20:22]).tolist(), [16,17,20,21])
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(Tensor.cat(t[16:18], t[20:22]).tolist(), [16,17,20,21])
|
||||
|
||||
# TODO: fix this! at least assert on it
|
||||
@unittest.expectedFailure
|
||||
def test_slice_sum(self):
|
||||
t = Tensor(self.fn)
|
||||
self.assertListEqual((t[16:18]+t[20:22]).tolist(), [16+20,17+21])
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual((t[16:18]+t[20:22]).tolist(), [16+20,17+21])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -218,7 +218,9 @@ def late_buffer_view(t:UOp, b:UOp):
|
|||
|
||||
# walk up for the INDEX
|
||||
x = t
|
||||
while not any(u.op is Ops.INDEX for u in x.src): x = x.src[0]
|
||||
while not any(u.op is Ops.INDEX for u in x.src):
|
||||
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
|
||||
x = x.src[0]
|
||||
x = next(u for u in x.src if u.op is Ops.INDEX)
|
||||
|
||||
if len(shape) == 0: offset = x.src[1].arg
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue