Add Interpolate Function (#5482)

* add interpolate function

* fixed linter issue

* reduced sizes in test

---------

Co-authored-by: wozeparrot <wozeparrot@gmail.com>
This commit is contained in:
Tobias Fischer 2024-07-16 12:44:01 -04:00 committed by GitHub
commit 87a2ef2bc2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 63 additions and 0 deletions

View file

@ -1695,6 +1695,42 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
def test_interpolate_linear(self):
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"))
def test_interpolate_linear_corners_aligned(self):
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear", align_corners=True),
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True))
def test_interpolate_bilinear(self):
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="bilinear"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"), atol=1e-4)
def test_interpolate_bilinear_corners_aligned(self):
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="bilinear", align_corners=True),
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True), atol=1e-4)
def test_interpolate_trilinear(self):
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="trilinear"),
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"), atol=1e-4)
def test_interpolate_trilinear_corners_aligned(self):
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
helper_test_op([(2,3)+in_sz],
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="trilinear", align_corners=True),
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True), atol=1e-4)
def test_cat(self):
for dim in range(-2, 3):
helper_test_op([(45,65,9), (45,65,9), (45,65,9)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))

View file

@ -1922,6 +1922,33 @@ class Tensor:
"""
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
"""
Downsamples or Upsamples to the requested size, accepts 0 to N batch dimensions.
The type of sampling is selected with `mode` which currently only supports `linear`.
To run `bilinear` or `trilinear` pass in a 2D or 3D size.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.interpolate(size=(2,3), mode="linear").numpy())
```
"""
assert isinstance(size, (tuple,list)) and all(isinstance(s, int) for s in size) and len(size) > 0 and len(size) <= self.ndim
assert mode == "linear", "only linear interpolate supported right now"
x, expand = self, list(s for s in self.shape)
for i in range(-len(size), 0):
scale = (self.shape[i] - (1 if align_corners else 0)) / (size[i] - (1 if align_corners else 0))
arr, reshape = Tensor.arange(size[i]).cast(dtypes.float32), [1 for _ in range(self.ndim)]
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
reshape[i] = expand[i] = size[i]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
x = x.gather(i, low)*(1.0 - perc) + x.gather(i, high)*perc
return x
# ***** unary ops *****
def logical_not(self):