mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add fun with nhwc convs
This commit is contained in:
parent
907499b02c
commit
f5a3b33d33
1 changed files with 38 additions and 0 deletions
38
test/external/external_nhwc_conv.py
vendored
Normal file
38
test/external/external_nhwc_conv.py
vendored
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from tinygrad import Tensor, nn, Context, GlobalCounters
|
||||
|
||||
if __name__ == "__main__":
|
||||
conv = nn.Conv2d(64, 128, 3)
|
||||
img = Tensor.randn((1,64,128,128))
|
||||
with Context(DEBUG=0, BEAM=0):
|
||||
Tensor.realize(img, conv.weight, conv.bias)
|
||||
|
||||
tst = conv(img).permute(0,2,3,1).realize()
|
||||
print(tst.shape)
|
||||
|
||||
print("NEW")
|
||||
img_perm = img.permute(0,2,3,1).contiguous()
|
||||
print(img_perm.shape)
|
||||
pp = img_perm.permute(0,3,1,2)._pool((3,3)).permute(0,2,3,4,5,1)
|
||||
|
||||
def hwio(pp, conv):
|
||||
pp = pp.unsqueeze(-1)
|
||||
weight = conv.weight.permute(2,3,1,0).contiguous()
|
||||
print(pp.shape, weight.shape, (pp*weight).shape)
|
||||
return (pp * weight).sum([-4,-3, -2])
|
||||
|
||||
def ohwi(pp, conv):
|
||||
pp = pp.unsqueeze(-4)
|
||||
weight = conv.weight.permute(0,2,3,1).contiguous()
|
||||
print(pp.shape, weight.shape, (pp*weight).shape)
|
||||
return (pp * weight).sum([-3,-2,-1])
|
||||
|
||||
for f in [hwio, ohwi]:
|
||||
GlobalCounters.reset()
|
||||
print("\n**************", f.__name__, "**************")
|
||||
out = f(pp, conv)
|
||||
out.realize()
|
||||
print(out.shape)
|
||||
|
||||
with Context(DEBUG=0, BEAM=0):
|
||||
err = (tst-out).square()
|
||||
print(err.mean().item(), err.max().item())
|
||||
Loading…
Add table
Add a link
Reference in a new issue