mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix repeat_interleave with negative dim (#7734)
This commit is contained in:
parent
e3105675fb
commit
f1efd84c92
2 changed files with 3 additions and 1 deletions
|
|
@ -2047,6 +2047,8 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(6))
|
||||
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 1))
|
||||
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 0))
|
||||
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -1))
|
||||
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -2))
|
||||
|
||||
def test_simple_repeat(self):
|
||||
repeats = [3, 3, 4]
|
||||
|
|
|
|||
|
|
@ -1279,7 +1279,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
|||
print(t.repeat_interleave(2).numpy())
|
||||
```
|
||||
"""
|
||||
x, dim = (self.flatten(), 0) if dim is None else (self, dim)
|
||||
x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
|
||||
shp = x.shape
|
||||
return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue