mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Add Interpolate Function (#5482)
* add interpolate function * fixed linter issue * reduced sizes in test --------- Co-authored-by: wozeparrot <wozeparrot@gmail.com>
This commit is contained in:
parent
203161c75d
commit
87a2ef2bc2
2 changed files with 63 additions and 0 deletions
|
|
@ -1695,6 +1695,42 @@ class TestOps(unittest.TestCase):
|
|||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
|
||||
|
||||
def test_interpolate_linear(self):
|
||||
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"))
|
||||
|
||||
def test_interpolate_linear_corners_aligned(self):
|
||||
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="linear", align_corners=True),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True))
|
||||
|
||||
def test_interpolate_bilinear(self):
|
||||
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="bilinear"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"), atol=1e-4)
|
||||
|
||||
def test_interpolate_bilinear_corners_aligned(self):
|
||||
for in_sz, out_sz in [((52,40),(29,31)), ((52,29),(31,40)), ((29,31),(40,52))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="bilinear", align_corners=True),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True), atol=1e-4)
|
||||
|
||||
def test_interpolate_trilinear(self):
|
||||
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="trilinear"),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear"), atol=1e-4)
|
||||
|
||||
def test_interpolate_trilinear_corners_aligned(self):
|
||||
for in_sz, out_sz in [((5,2,8),(3,6,4))]:
|
||||
helper_test_op([(2,3)+in_sz],
|
||||
lambda x: torch.nn.functional.interpolate(x, size=out_sz, mode="trilinear", align_corners=True),
|
||||
lambda x: Tensor.interpolate(x, size=out_sz, mode="linear", align_corners=True), atol=1e-4)
|
||||
|
||||
def test_cat(self):
|
||||
for dim in range(-2, 3):
|
||||
helper_test_op([(45,65,9), (45,65,9), (45,65,9)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
|
||||
|
|
|
|||
|
|
@ -1922,6 +1922,33 @@ class Tensor:
|
|||
"""
|
||||
return Tensor._tri(self.shape[-2], self.shape[-1], diagonal=diagonal+1, device=self.device, dtype=dtypes.bool).where(0, self).cast(self.dtype)
|
||||
|
||||
def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool=False) -> Tensor:
|
||||
"""
|
||||
Downsamples or Upsamples to the requested size, accepts 0 to N batch dimensions.
|
||||
|
||||
The type of sampling is selected with `mode` which currently only supports `linear`.
|
||||
To run `bilinear` or `trilinear` pass in a 2D or 3D size.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 2, 3, 4], [21, 22, 23, 24], [41, 42, 43, 44]])
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.interpolate(size=(2,3), mode="linear").numpy())
|
||||
```
|
||||
"""
|
||||
assert isinstance(size, (tuple,list)) and all(isinstance(s, int) for s in size) and len(size) > 0 and len(size) <= self.ndim
|
||||
assert mode == "linear", "only linear interpolate supported right now"
|
||||
x, expand = self, list(s for s in self.shape)
|
||||
for i in range(-len(size), 0):
|
||||
scale = (self.shape[i] - (1 if align_corners else 0)) / (size[i] - (1 if align_corners else 0))
|
||||
arr, reshape = Tensor.arange(size[i]).cast(dtypes.float32), [1 for _ in range(self.ndim)]
|
||||
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
|
||||
reshape[i] = expand[i] = size[i]
|
||||
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
|
||||
x = x.gather(i, low)*(1.0 - perc) + x.gather(i, high)*perc
|
||||
return x
|
||||
|
||||
# ***** unary ops *****
|
||||
|
||||
def logical_not(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue