mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fast im2col
This commit is contained in:
parent
c9968756d1
commit
67506eb6ba
2 changed files with 11 additions and 7 deletions
|
|
@ -67,7 +67,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
class TestOps(unittest.TestCase):
|
||||
def test_conv2d(self):
|
||||
x = torch.randn((5,2,10,7), requires_grad=True)
|
||||
w = torch.randn((4,2,3,3), requires_grad=True)
|
||||
w = torch.randn((4,2,3,2), requires_grad=True)
|
||||
xt = Tensor(x.detach().numpy())
|
||||
wt = Tensor(w.detach().numpy())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import numpy as np
|
||||
from functools import lru_cache
|
||||
|
||||
def mask_like(like, mask_inx, mask_value = 1.0):
|
||||
mask = np.zeros_like(like).reshape(-1)
|
||||
|
|
@ -31,14 +32,17 @@ def fetch_mnist():
|
|||
# these are matlab functions used to speed up convs
|
||||
# write them fast and the convs will be fast?
|
||||
|
||||
# Explicit maxsize (same value as the default) so the decorator also works on
# Python < 3.8, where the bare `@lru_cache` form raises TypeError.
@lru_cache(maxsize=128)
def get_im2col_indexes(oy, ox, cin, H, W):
  """Return the (channel, row, col) gather indexes used by im2col.

  The three returned 1-D integer arrays each have length oy*ox*cin*H*W.
  Indexing a 4-D array as ``x[:, idxc, idxy, idxx]`` and reshaping the result
  to ``(-1, cin*H*W)`` yields one row per output position (Y, X), with Y
  varying slowest, where each row is ``x[:, :, Y:Y+H, X:X+W]`` flattened in
  (channel, row, col) order.

  Args:
    oy, ox: number of output positions along the row and column axes.
    cin: number of input channels.
    H, W: kernel height and width.

  Returns:
    Tuple ``(idxc, idxy, idxx)`` of read-only numpy index arrays. They are
    cached per argument tuple and shared between callers, hence read-only.
  """
  patch = cin*H*W  # number of elements gathered for one output position

  # channel index: each channel repeated over an HxW window, once per output position
  idxc = np.tile(np.arange(cin).repeat(H*W), oy*ox)
  # row index: offset inside the window plus the output row Y
  idxy = np.tile(np.arange(H).repeat(W), oy*ox*cin) + np.arange(oy).repeat(ox*patch)
  # column index: offset inside the window plus the output column X
  idxx = np.tile(np.arange(W), oy*ox*cin*H) + np.tile(np.arange(ox), oy).repeat(patch)

  # The same array objects are handed to every caller with this geometry;
  # freeze them so an accidental in-place write cannot poison the cache.
  for idx in (idxc, idxy, idxx):
    idx.setflags(write=False)
  return idxc, idxy, idxx
|
||||
|
||||
def im2col(x, H, W):
  """Unfold every HxW window of a 4-D array into a row of a 2-D matrix.

  Args:
    x: 4-D numpy array; axis 1 is treated as the channel axis and axes 2 and 3
      as the two spatial axes the window slides over.
    H, W: window (kernel) height and width.

  Returns:
    Array of shape ``(x.shape[0]*oy*ox, cin*H*W)`` where
    ``oy = x.shape[2]-(H-1)`` and ``ox = x.shape[3]-(W-1)``.
  """
  cin = x.shape[1]
  oy, ox = x.shape[2]-(H-1), x.shape[3]-(W-1)

  # A single fancy-indexing gather replaces the former per-(Y, X) Python loop,
  # whose result was overwritten by this gather anyway.
  ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
  tx = x[:, ic, iy, ix]
  return tx.reshape(-1, cin*W*H)
|
||||
|
||||
def col2im(tx, H, W, OY, OX):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue