mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
winograd should be 4 kernels (#3268)
This commit is contained in:
parent
f48b6aca77
commit
085dc87bed
2 changed files with 9 additions and 2 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
|
||||
|
|
@ -35,5 +35,12 @@ class TestWinograd(unittest.TestCase):
|
|||
out = Tensor.conv2d(x,w).realize()
|
||||
out.numpy()
|
||||
|
||||
def test_four_kernels(self):
|
||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||
GlobalCounters.reset()
|
||||
out = Tensor.conv2d(x,w).realize()
|
||||
assert GlobalCounters.kernel_count == 4
|
||||
out.numpy()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
@ -218,7 +218,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
|
|||
# realize all places where the buffer is expanded
|
||||
if prod(buf.base.st.shape) < prod(buf.st.shape):
|
||||
if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
|
||||
prod(buf.base.st.shape) == prod([y-x for x,y in buf.st.views[-1].mask]):
|
||||
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
|
||||
simple_pads.add(buf.base)
|
||||
else:
|
||||
realizes.add(buf.base)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue