mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
* preallocate all realized buffers
* contiguous
* work
* comment that out
* move to schedule
* better
* correct fix
* just buffer
* disk bufs
* fixes disk tensor stuff
* fix symbolic stuff
* fix multi
* 162 failures
* bugfixes
* don't check that anymore
* fix schedule tests
* mnist should be contiguous
* type and buffer
* fix tests
* shrink axis correction
* mypy fixes
* tests skips
* same 37 failures
* dedup
* no shrink in the graph
* 29 failures
* skips
* fix custom kernel
* fix training
* those optimizations aren't supported currently
* simpler
* more correct
* tests
* 14 failures
* works
* fix that test
* broken
* 11 failures
* only kernel counts left
* fixes
* all tests pass
* remove tensor_map
* op test
* 200 -> 230
* test fixes
* fixes
* revert test_tiny thing
* guard
* revert that
* test tiny passes
* no contigs there
* base realize back
* Revert "no contigs there"
This reverts commit c45bb9fcfd.
* revert that
* chop many assigns
* 12 failures
* fix tests
* tests
* apply after
* pre-commit
* remove old code
* delete that
* fix types
* remove extra contig
* fix dataloader
* torch fix
* disk fix
* update kernel fusion numbers
* runs on amd
* restore kernel count
* add that rule back
* that
* disable that
* wrong
* add the correct rule for that folding
* more tests
* guard c1.arg
* no newlines
* realize those
* split into a different file
* remove detach/contig back
* skip 2
* update that
144 lines
4.5 KiB
Python
144 lines
4.5 KiB
Python
# simple tests
|
|
import unittest
|
|
import torch
|
|
import warnings
|
|
from tinygrad.helpers import getenv, GlobalCounters
|
|
# Pick the torch backend under test. TINY_BACKEND2 selects backend2 and the tests put tensors on the "cpu" device;
# otherwise the default backend is used and tensors go to the "tiny" device.
# NOTE(review): neither backend module is referenced by name afterwards; importing it is presumably what registers
# the device with torch — confirm.
if not getenv("TINY_BACKEND2"):
  import extra.torch_backend.backend
  device = "tiny"
else:
  import extra.torch_backend.backend2
  device = "cpu"
|
|
|
|
|
|
class TestKernelFusionRegression(unittest.TestCase):
  """Cap the number of kernels that small torch workloads run on the selected backend.

  Each test builds its workload inside a zero-argument closure and hands it to `_check_kernel_count`.
  That helper fails the test when more kernels than the recorded budget ran, and only warns when fewer
  ran, so the budget can be tightened by hand.
  """

  def _realize(self, t):
    """Realize `t` by detaching it from autograd and copying it to a host numpy array."""
    _ = t.detach().cpu().numpy()

  def _check_kernel_count(self, fn, expected_kernels):
    """Run `fn`, realize its result and assert that at most `expected_kernels` kernels ran.

    Args:
      fn: zero-argument callable returning a torch tensor. The counters are reset before it is called,
        so any kernels run while it builds its own inputs are counted as well.
      expected_kernels: upper bound on `GlobalCounters.kernel_count`.

    Warns with a UserWarning when fewer kernels ran than expected, meaning the bound can be lowered.
    """
    torch.manual_seed(42)  # seed torch's RNG so every run starts from the same state
    GlobalCounters.reset()
    # realize the result so the kernels it depends on have run before the counter is read
    self._realize(fn())
    kernels = GlobalCounters.kernel_count
    expectation = f"{kernels} vs {expected_kernels} expected."
    if kernels < expected_kernels:
      # stacklevel=2 attributes the warning to the calling test method rather than to this helper
      warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning, stacklevel=2)
    self.assertLessEqual(kernels, expected_kernels, expectation)

  def test_elementwise_fusion(self):
    def fn():
      x = torch.randn(128, 128, device=device)
      return (x + 1.0) * 2.0 - 0.5
    self._check_kernel_count(fn, 5)

  def test_relu_fusion(self):
    def fn():
      x = torch.randn(1, 3, 32, 32, device=device)
      conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)
      with torch.no_grad():
        return torch.nn.functional.relu(conv(x))
    self._check_kernel_count(fn, 8)

  def test_batchnorm_fusion(self):
    def fn():
      x = torch.randn(2, 3, 16, 16, device=device)
      conv = torch.nn.Conv2d(3, 8, 3, padding=1).to(device)
      bn = torch.nn.BatchNorm2d(8).to(device)
      bn.eval()
      with torch.no_grad():
        return torch.nn.functional.relu(bn(conv(x)))
    self._check_kernel_count(fn, 16)

  def test_reduce_fusion(self):
    def fn():
      x = torch.randn(64, 64, device=device)
      return (x * 2.0).sum()
    self._check_kernel_count(fn, 5)

  def test_matmul_elementwise_fusion(self):
    def fn():
      x = torch.randn(32, 32, device=device)
      w = torch.randn(32, 32, device=device)
      return torch.nn.functional.relu(x @ w + 1.0)
    self._check_kernel_count(fn, 7)

  def test_pooling_fusion(self):
    def fn():
      x = torch.randn(1, 8, 16, 16, device=device)
      return torch.nn.functional.max_pool2d(x * 2.0, 2)
    self._check_kernel_count(fn, 5)

  def test_residual_add_relu_fusion(self):
    def fn():
      x = torch.randn(1, 8, 16, 16, device=device)
      identity = torch.randn(1, 8, 16, 16, device=device)
      out = x + identity
      return torch.nn.functional.relu(out)
    self._check_kernel_count(fn, 7)

  def test_inplace_add_relu_fusion(self):
    def fn():
      x = torch.randn(1, 16, 32, 32, device=device)
      y = torch.randn(1, 16, 32, 32, device=device)
      x += y
      return torch.nn.functional.relu(x)
    self._check_kernel_count(fn, 7)

  def test_conv_bn_add_relu_fusion(self):
    def fn():
      x = torch.randn(1, 8, 16, 16, device=device)
      identity = torch.randn(1, 8, 16, 16, device=device)
      conv = torch.nn.Conv2d(8, 8, 3, padding=1, bias=False).to(device)
      bn = torch.nn.BatchNorm2d(8).to(device)
      bn.eval()
      with torch.no_grad():
        out = bn(conv(x))
        out += identity
        return torch.nn.functional.relu(out)
    self._check_kernel_count(fn, 17)

  def test_multiple_inplace_ops_fusion(self):
    def fn():
      x = torch.randn(64, 64, device=device)
      x += 1.0
      x *= 2.0
      return torch.nn.functional.relu(x)
    self._check_kernel_count(fn, 4)

  def test_view_inplace_no_fusion_break(self):
    def fn():
      x = torch.randn(4, 64, device=device)
      view = x[1:3]
      view += 1.0  # in-place update through a slice view; the sum below reads the base tensor
      return x.sum()
    self._check_kernel_count(fn, 8)

  def test_batchnorm_running_stats_update(self):
    def fn():
      x = torch.randn(2, 8, 8, 8, device=device)
      bn = torch.nn.BatchNorm2d(8).to(device)
      bn.train()
      with torch.no_grad():
        return bn(x)
    self._check_kernel_count(fn, 10)

  # minimal version of extra/other_mnist/beautiful_mnist_torch.py, covering fusion for a training step with an optimizer
  def test_mnist_training_fusion(self):
    def fn():
      model = torch.nn.Sequential(
        torch.nn.Conv2d(1, 8, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2),
        torch.nn.Flatten(),
        torch.nn.Linear(8*14*14, 10)
      ).to(device)
      optimizer = torch.optim.Adam(model.parameters(), 1e-3)
      x = torch.randn(32, 1, 28, 28, device=device)
      labels = torch.randint(0, 10, (32,), device=device)
      out = model(x)
      loss = torch.nn.functional.cross_entropy(out, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      return loss
    self._check_kernel_count(fn, 28)
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__": unittest.main()
|