Tensor.repeat cleanup (#7735)

flatten instead of double for loop comprehension
This commit is contained in:
chenyu 2024-11-16 10:43:45 -05:00 committed by GitHub
commit e777211a00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1297,11 +1297,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
```
"""
repeats = argfix(repeats, *args)
base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
new_shape = [x for b in base_shape for x in [1, b]]
expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
base_shape = _pad_left(self.shape, repeats)[0]
unsqueezed_shape = flatten([[1, s] for s in base_shape])
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):