fix repeat_interleave with negative dim (#7734)

This commit is contained in:
chenyu 2024-11-16 10:15:29 -05:00 committed by GitHub
commit f1efd84c92
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletions

View file

@ -2047,6 +2047,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(6))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 1))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 0))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -1))
helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -2))
def test_simple_repeat(self):
repeats = [3, 3, 4]

View file

@ -1279,7 +1279,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.repeat_interleave(2).numpy())
```
"""
x, dim = (self.flatten(), 0) if dim is None else (self, dim)
x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
shp = x.shape
return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])