mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
test exact kernel count in torch_backend/test_kernel_fusion (#15091)
This commit is contained in:
parent
f80b1033c5
commit
71f228f80f
1 changed files with 6 additions and 9 deletions
|
|
@ -1,7 +1,6 @@
|
|||
# simple tests
|
||||
import unittest
|
||||
import torch
|
||||
import warnings
|
||||
from tinygrad.helpers import getenv, GlobalCounters
|
||||
if getenv("TINY_BACKEND2"):
|
||||
import extra.torch_backend.backend2
|
||||
|
|
@ -18,9 +17,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
torch.manual_seed(42)
|
||||
GlobalCounters.reset()
|
||||
fn().detach().cpu().numpy()
|
||||
expectation = f"{GlobalCounters.kernel_count} vs {expected_kernels} expected."
|
||||
if GlobalCounters.kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
|
||||
self.assertLessEqual(GlobalCounters.kernel_count, expected_kernels, f"{expectation}")
|
||||
self.assertEqual(GlobalCounters.kernel_count, expected_kernels)
|
||||
|
||||
def test_elementwise_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -34,7 +31,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)
|
||||
with torch.no_grad():
|
||||
return torch.nn.functional.relu(conv(x))
|
||||
self._check_kernel_count(fn, 8)
|
||||
self._check_kernel_count(fn, 6)
|
||||
|
||||
def test_batchnorm_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -44,7 +41,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
bn.eval()
|
||||
with torch.no_grad():
|
||||
return torch.nn.functional.relu(bn(conv(x)))
|
||||
self._check_kernel_count(fn, 16)
|
||||
self._check_kernel_count(fn, 10)
|
||||
|
||||
def test_reduce_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -92,7 +89,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
out = bn(conv(x))
|
||||
out += identity
|
||||
return torch.nn.functional.relu(out)
|
||||
self._check_kernel_count(fn, 17)
|
||||
self._check_kernel_count(fn, 12)
|
||||
|
||||
def test_multiple_inplace_ops_fusion(self):
|
||||
def fn():
|
||||
|
|
@ -117,7 +114,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
bn.train()
|
||||
with torch.no_grad():
|
||||
return bn(x)
|
||||
self._check_kernel_count(fn, 10)
|
||||
self._check_kernel_count(fn, 8)
|
||||
|
||||
# this is a minimal extra/other_mnist/beautiful_mnist_torch.py to cover fusion for training with optimizer
|
||||
def test_mnist_training_fusion(self):
|
||||
|
|
@ -138,7 +135,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss
|
||||
self._check_kernel_count(fn, 28)
|
||||
self._check_kernel_count(fn, 24)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue