winograd should be 4 kernels (#3268)

This commit is contained in:
George Hotz 2024-01-28 09:21:26 -08:00 committed by GitHub
commit 085dc87bed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 2 deletions

View file

@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import Timing, CI, Profiling, WINO
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
@ -35,5 +35,12 @@ class TestWinograd(unittest.TestCase):
out = Tensor.conv2d(x,w).realize()
out.numpy()
def test_four_kernels(self):
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
GlobalCounters.reset()
out = Tensor.conv2d(x,w).realize()
assert GlobalCounters.kernel_count == 4
out.numpy()
if __name__ == '__main__':
unittest.main(verbosity=2)

View file

@ -218,7 +218,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
# realize all places where the buffer is expanded
if prod(buf.base.st.shape) < prod(buf.st.shape):
if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) == prod([y-x for x,y in buf.st.views[-1].mask]):
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
else:
realizes.add(buf.base)