mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Merge branch 'master' into single_file_onnx/1
This commit is contained in:
commit
8b59601ca6
16 changed files with 386 additions and 328 deletions
44
.github/workflows/test.yml
vendored
44
.github/workflows/test.yml
vendored
|
|
@ -870,28 +870,28 @@ jobs:
|
|||
run: WEBGPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_runner.py
|
||||
|
||||
osxremote:
|
||||
name: MacOS (remote metal)
|
||||
runs-on: macos-15
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REMOTE: 1
|
||||
REMOTEDEV: METAL
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: macos-remote
|
||||
deps: testing_minimal
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
|
||||
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
|
||||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
|
||||
name: MacOS (remote metal)
|
||||
runs-on: macos-15
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REMOTE: 1
|
||||
REMOTEDEV: METAL
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: macos-remote
|
||||
deps: testing_minimal
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
|
||||
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
|
||||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
|
||||
|
||||
amdremote:
|
||||
name: Linux (remote)
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -9,7 +9,7 @@ with open(directory / 'README.md', encoding='utf-8') as f:
|
|||
|
||||
testing_minimal = [
|
||||
"numpy",
|
||||
"torch",
|
||||
"torch==2.7.1",
|
||||
"pytest",
|
||||
"pytest-xdist",
|
||||
"hypothesis",
|
||||
|
|
|
|||
197
test/test_jit.py
197
test/test_jit.py
|
|
@ -5,10 +5,10 @@ import numpy as np
|
|||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit, GraphRunner
|
||||
from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import Context, JIT, GlobalCounters
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
|
||||
from tinygrad.dtype import dtypes
|
||||
from extra.models.unet import ResBlock
|
||||
|
||||
|
|
@ -472,32 +472,6 @@ class TestJit(unittest.TestCase):
|
|||
np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
@unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
|
||||
def test_jit_several_incompatible_devs(self):
|
||||
assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
|
||||
assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None
|
||||
|
||||
d0, d1 = Device.DEFAULT, "CPU"
|
||||
|
||||
@TinyJit
|
||||
def f(a0, b0):
|
||||
a1 = (a + 2.0).contiguous().realize()
|
||||
a2 = (a1 * 2.0).contiguous().realize()
|
||||
|
||||
b1 = (b0 + 2.0).contiguous().realize()
|
||||
b2 = (b1 * 2.0).contiguous().realize()
|
||||
|
||||
return a2, b2
|
||||
|
||||
for _ in range(5):
|
||||
a = Tensor.randn(10, 10, device=d0).realize()
|
||||
b = Tensor.randn(10, 10, device=d1).realize()
|
||||
a1, b1 = f(a, b)
|
||||
np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
|
||||
|
||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||
def test_jitted_view(self):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||
|
|
@ -696,5 +670,170 @@ class TestJitFree(unittest.TestCase):
|
|||
out = fxn(Tensor([11,1,2,3,4]))
|
||||
self.assertEqual(out.item(), 13600)
|
||||
|
||||
class TestJitGraphSplit(unittest.TestCase):
|
||||
def compute(self, device, inp):
|
||||
assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
|
||||
return (inp + 1.0).contiguous().realize()
|
||||
|
||||
def copy(self, device, to_device, inp):
|
||||
assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
|
||||
return inp.to(to_device).realize()
|
||||
|
||||
def expect(self, f, *args, graph=None, multigraph=None, hcqgraph=None):
|
||||
def _numpies(tpl): return tpl.numpy() if tpl.__class__ is Tensor else tuple([t.numpy() for t in tpl])
|
||||
|
||||
expected = _numpies(f(*args))
|
||||
for i in range(4):
|
||||
res = _numpies(f(*args))
|
||||
np.testing.assert_allclose(res, expected, atol=1e-4, rtol=1e-5)
|
||||
|
||||
dev = Device[Device.DEFAULT]
|
||||
graph_t = graph_class(dev)
|
||||
if graph_t is None: return
|
||||
|
||||
got = f.jit_cache
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
if graph_t is HCQGraph:
|
||||
validate = hcqgraph
|
||||
elif issubclass(graph_t, MultiGraphRunner):
|
||||
validate = multigraph
|
||||
else:
|
||||
validate = graph
|
||||
|
||||
assert len(got) == len(validate), f"Expected {len(validate)} operations, got {len(got)}"
|
||||
for expected, got in zip(validate, got):
|
||||
if expected["type"] == "graph":
|
||||
assert isinstance(got.prg, GraphRunner), f"Expected GraphRunner, got {type(got.prg)}"
|
||||
assert len(got.prg.jit_cache) == expected["cnt"], f"Expected {expected['cnt']} operations in graph, got {len(got.prg.jit_cache)}"
|
||||
elif expected["type"] == "comp":
|
||||
assert isinstance(got.prg, CompiledRunner), f"Expected CompiledRunner, got {type(got.prg)}"
|
||||
elif expected["type"] == "copy":
|
||||
assert isinstance(got.prg, BufferCopy), f"Expected BufferCopy, got {type(got.prg)}"
|
||||
elif expected["type"] == "xfer":
|
||||
assert isinstance(got.prg, BufferXfer), f"Expected BufferXfer, got {type(got.prg)}"
|
||||
|
||||
def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
|
||||
def ji_comp(self): return {"type": "comp"}
|
||||
def ji_copy(self): return {"type": "copy"}
|
||||
def ji_xfer(self): return {"type": "xfer"}
|
||||
|
||||
def test_jit_split_simple(self):
|
||||
if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
|
||||
|
||||
@TinyJit
|
||||
def f(inp):
|
||||
op0 = self.compute(Device.DEFAULT, inp)
|
||||
op1 = self.compute(Device.DEFAULT, op0)
|
||||
op2 = self.compute(Device.DEFAULT, op1)
|
||||
return op2
|
||||
|
||||
inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
|
||||
self.expect(f, inp,
|
||||
graph=[self.ji_graph(3)],
|
||||
multigraph=[self.ji_graph(3)],
|
||||
hcqgraph=[self.ji_graph(3)])
|
||||
|
||||
def test_jit_cpu_simple(self):
|
||||
if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
|
||||
|
||||
@TinyJit
|
||||
def f(inp, inp_cpu):
|
||||
op0 = self.compute(Device.DEFAULT, inp)
|
||||
op1 = self.compute(Device.DEFAULT, op0)
|
||||
op2 = self.compute("CPU", inp_cpu)
|
||||
op3 = self.compute(Device.DEFAULT, op1)
|
||||
return op2, op3
|
||||
|
||||
inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
|
||||
inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
|
||||
self.expect(f, inp, inp_cpu,
|
||||
graph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
|
||||
multigraph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
|
||||
hcqgraph=[self.ji_graph(4)])
|
||||
|
||||
def test_jit_cpu_several(self):
|
||||
if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
|
||||
|
||||
@TinyJit
|
||||
def f(inp, inp_cpu):
|
||||
op0 = self.compute(Device.DEFAULT, inp)
|
||||
op1 = self.compute(Device.DEFAULT, op0)
|
||||
op2 = self.compute("CPU", inp_cpu)
|
||||
op3 = self.compute("CPU", op2)
|
||||
op4 = self.compute(Device.DEFAULT, op1)
|
||||
return op3, op4
|
||||
|
||||
inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
|
||||
inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
|
||||
self.expect(f, inp, inp_cpu,
|
||||
graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
|
||||
multigraph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
|
||||
hcqgraph=[self.ji_graph(5)])
|
||||
|
||||
def test_jit_multidev(self):
|
||||
if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")
|
||||
|
||||
try: Device[f"{Device.DEFAULT}:1"]
|
||||
except Exception: raise unittest.SkipTest("no multidevice")
|
||||
|
||||
@TinyJit
|
||||
def f(inp, inp_d1):
|
||||
op0 = self.compute(Device.DEFAULT, inp)
|
||||
op1 = self.compute(Device.DEFAULT, op0)
|
||||
op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
|
||||
op3 = self.compute(f"{Device.DEFAULT}:1", op2)
|
||||
op4 = self.compute(Device.DEFAULT, op1)
|
||||
return op3, op4
|
||||
|
||||
inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
|
||||
inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
|
||||
self.expect(f, inp, inp_d1,
|
||||
graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
|
||||
multigraph=[self.ji_graph(5)],
|
||||
hcqgraph=[self.ji_graph(5)])
|
||||
|
||||
@unittest.skip("flaky")
|
||||
def test_jit_multidev_xfer(self):
|
||||
if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
|
||||
|
||||
try: Device[f"{Device.DEFAULT}:1"]
|
||||
except Exception: raise unittest.SkipTest("no multidevice")
|
||||
|
||||
@TinyJit
|
||||
def f(inp, inp_d1):
|
||||
op0 = self.compute(Device.DEFAULT, inp)
|
||||
op1 = self.compute(Device.DEFAULT, op0)
|
||||
op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
|
||||
op3 = self.copy(f"{Device.DEFAULT}:1", Device.DEFAULT, op2)
|
||||
op4 = self.compute(f"{Device.DEFAULT}:1", op2)
|
||||
op5 = self.compute(Device.DEFAULT, op3)
|
||||
return op1, op4, op5
|
||||
|
||||
inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
|
||||
inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
|
||||
self.expect(f, inp, inp_d1,
|
||||
graph=[self.ji_graph(2), self.ji_comp(), self.ji_xfer(), self.ji_comp(), self.ji_comp()],
|
||||
multigraph=[self.ji_graph(6)],
|
||||
hcqgraph=[self.ji_graph(6)])
|
||||
|
||||
@unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies")
|
||||
def test_jit_multidev_copy(self):
|
||||
if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
|
||||
if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")
|
||||
|
||||
@TinyJit
|
||||
def f(inp):
|
||||
op0 = self.compute(Device.DEFAULT, inp)
|
||||
op1 = self.compute(Device.DEFAULT, op0)
|
||||
op2 = self.copy(Device.DEFAULT, "CPU", op1)
|
||||
op3 = self.compute("CPU", op2)
|
||||
return op3
|
||||
|
||||
inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
|
||||
self.expect(f, inp,
|
||||
graph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
|
||||
multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
|
||||
hcqgraph=[self.ji_graph(4)])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
134
test/test_nn.py
134
test/test_nn.py
|
|
@ -108,105 +108,39 @@ class TestNN(unittest.TestCase):
|
|||
_test_linear(Tensor.randn(BS, in_dim), in_dim, out_dim)
|
||||
_test_linear(Tensor.randn(BS, T, in_dim), in_dim, out_dim) # test with more dims
|
||||
|
||||
def test_conv1d(self):
|
||||
BS, C1, W = 4, 16, 224//4
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
def _test_conv(self, tiny_conv, torch_conv, BS, C1, DIMS, C2, K, S, P, D=1):
|
||||
# create in tinygrad
|
||||
layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
layer = tiny_conv(C1, C2, kernel_size=K, stride=S, padding=P, dilation=D)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer = torch_conv(C1, C2, kernel_size=K, stride=S, padding=P, dilation=D).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, W)
|
||||
x = Tensor.uniform(BS, C1, *DIMS)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def test_conv2d(self):
|
||||
BS, C1, H, W = 4, 16, 224//4, 224//4
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
# create in tinygrad
|
||||
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
def test_conv1d(self): self._test_conv(Conv1d, torch.nn.Conv1d, BS=4, C1=16, DIMS=[224//4], C2=64, K=7, S=2, P=1)
|
||||
def test_conv2d(self): self._test_conv(Conv2d, torch.nn.Conv2d, BS=4, C1=16, DIMS=[224//4, 224//4], C2=64, K=7, S=2, P=1)
|
||||
|
||||
def test_conv1d_same_padding(self):
|
||||
BS, C1, W = 8, 3, 32
|
||||
C2, K, S, P = 16, 3, 1, 'same'
|
||||
|
||||
# create in tinygrad
|
||||
layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def _run_conv2d_same_padding_test(self, BS, C1, C2, H, W, K, S, padding='same', D=1):
|
||||
# create in tinygrad
|
||||
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
self._test_conv(Conv1d, torch.nn.Conv1d, BS=8, C1=3, DIMS=[32], C2=16, K=3, S=1, P='same')
|
||||
def test_conv2d_same_padding_odd_input(self):
|
||||
BS, C1, H, W = 16, 16, 29, 31
|
||||
C2, K, S, P = 32, 5, 1, 'same'
|
||||
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
|
||||
|
||||
self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=16, DIMS=[29, 31], C2=32, K=5, S=1, P='same')
|
||||
def test_conv2d_same_padding_large_kernel(self):
|
||||
BS, C1, H, W = 16, 16, 28, 33
|
||||
C2, K, S, P = 32, 9, 1, 'same'
|
||||
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
|
||||
|
||||
self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=16, DIMS=[28, 33], C2=32, K=9, S=1, P='same')
|
||||
def test_conv2d_same_padding_with_dilation(self):
|
||||
BS, C1, H, W = 16, 3, 28, 28
|
||||
C2, K, S, P, D = 32, 3, 1, 'same', 3
|
||||
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P, D)
|
||||
self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=3, DIMS=[28, 28], C2=32, K=3, S=1, P='same', D=3)
|
||||
|
||||
def test_conv2d_same_padding_invalid_stride(self):
|
||||
C1, C2, K, S, P = 16, 32, 2, 2, 'same'
|
||||
self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
self.assertRaises(ValueError, Conv2d, in_channels=16, out_channels=32, kernel_size=2, stride=2, padding='same')
|
||||
def test_conv2d_same_padding_invalid_padding_str(self):
|
||||
C1, C2, K, S, P = 16, 32, 2, 1, 'not_same'
|
||||
self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
self.assertRaises(ValueError, Conv2d, in_channels=16, out_channels=32, kernel_size=2, stride=1, padding='not_same')
|
||||
|
||||
@unittest.skip("Takes too long to compile for Compiled backends")
|
||||
def test_conv2d_winograd(self):
|
||||
|
|
@ -229,12 +163,13 @@ class TestNN(unittest.TestCase):
|
|||
with Context(WINO=1):
|
||||
z = layer(x)
|
||||
|
||||
m = z.mean()
|
||||
m.backward()
|
||||
|
||||
torch_x = torch.tensor(x.numpy(), requires_grad=True)
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
m = z.mean()
|
||||
m.backward()
|
||||
gw = layer.weight.grad.realize()
|
||||
gb = layer.bias.grad.realize()
|
||||
gx = x.grad.realize()
|
||||
|
|
@ -245,44 +180,9 @@ class TestNN(unittest.TestCase):
|
|||
np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def test_conv_transpose1d(self):
|
||||
BS, C1, W = 4, 16, 224//4
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
# create in tinygrad
|
||||
layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
self._test_conv(ConvTranspose1d, torch.nn.ConvTranspose1d, BS=4, C1=16, DIMS=[224//4], C2=64, K=7, S=2, P=1)
|
||||
def test_conv_transpose2d(self):
|
||||
BS, C1, H, W = 4, 16, 224//4, 224//4
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
# create in tinygrad
|
||||
layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
self._test_conv(ConvTranspose2d, torch.nn.ConvTranspose2d, BS=4, C1=16, DIMS=[224//4, 224//4], C2=64, K=7, S=2, P=1)
|
||||
|
||||
def test_groupnorm(self):
|
||||
BS, H, W, C, G = 20, 10, 10, 6, 3
|
||||
|
|
|
|||
|
|
@ -892,13 +892,13 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_overflow_sym(self):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
|
||||
def test_regular(self):
|
||||
self.do_op_then_assert(dtypes.int, 64, 64, 64)
|
||||
|
||||
def test_regular_sym(self):
|
||||
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 0, 64).bind(32))
|
||||
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32))
|
||||
|
||||
@unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64")
|
||||
def test_symfold(self):
|
||||
|
|
@ -910,7 +910,7 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow_sym(self):
|
||||
with self.assertRaises(KeyError):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow(self):
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
|
|||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.optional import get_late_rewrite_patterns
|
||||
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.opt import pm_optimize
|
||||
|
||||
@dataclass
|
||||
class RewriteStep:
|
||||
|
|
@ -42,6 +43,10 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
|
|||
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
|
||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
ret: list[RewriteStep] = []
|
||||
|
||||
# this is kernel.py
|
||||
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
|
|||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
from tinygrad.engine.schedule import ScheduleItem
|
||||
from tinygrad.opt import get_optimized_ast
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
# **************** Program Creation ****************
|
||||
|
||||
|
|
@ -27,16 +25,13 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
|
|||
"""
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
|
||||
if __debug__: type_verify(list(modified_ast.toposort()))
|
||||
|
||||
# linearize
|
||||
try:
|
||||
uops = full_rewrite(modified_ast, renderer)
|
||||
uops = full_rewrite(ast, renderer)
|
||||
except RuntimeError:
|
||||
print("***** LINEARIZE FAILURE *****")
|
||||
print(f"ast = {ast}")
|
||||
print(f"opts = {modified_ast.arg.applied_opts}")
|
||||
raise
|
||||
assert uops[-1].op is Ops.SINK, "last uop must be sink"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@
|
|||
|
||||
from tinygrad.opt.kernel import Kernel
|
||||
from tinygrad.opt.heuristic import hand_coded_optimizations
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
|
||||
from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
|
||||
"""
|
||||
|
|
@ -27,4 +28,11 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
|
|||
kb = Kernel(ast, opts=renderer)
|
||||
rawbufs = bufs_from_lin(kb, allocate=False)
|
||||
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
return k.get_optimized_ast()
|
||||
ret = k.get_optimized_ast()
|
||||
if __debug__: type_verify(list(ret.toposort()))
|
||||
return ret
|
||||
|
||||
pm_optimize = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="ast"), lambda ctx,ast:
|
||||
get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
|
|||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||
from tinygrad.schedule.kernelize import view_left
|
||||
from tinygrad.opt.swizzler import view_left, view_left_through_load
|
||||
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
|
|
@ -503,4 +503,4 @@ class Kernel:
|
|||
self.finalized = True
|
||||
fixed_ast = fixup_ast(self.ast)
|
||||
del fixup_ast
|
||||
return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")
|
||||
return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")
|
||||
|
|
|
|||
137
tinygrad/opt/swizzler.py
Normal file
137
tinygrad/opt/swizzler.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
|
||||
from tinygrad.helpers import all_same, prod, unwrap, colored
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
|
||||
merge_views = PatternMatcher([
|
||||
# merge adjacent views
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# replace MovementOps with VIEW
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
|
||||
# remove NOOP views
|
||||
(UPat.var("x").view(name="view"),
|
||||
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# only unmaksed VIEW on CONST replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
])
|
||||
|
||||
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
|
||||
# contiguous, expand, and the same with ones removed
|
||||
if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
|
||||
tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
|
||||
new_shape: list[sint] = []
|
||||
new_reduce_axis = []
|
||||
if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
|
||||
for i,pairs in enumerate(contraction):
|
||||
new_shape_chunk = [view.shape[p] for p in pairs]
|
||||
if i in r.arg[1]:
|
||||
# if this is a reduce axis, we need a 1 in the view here to put it
|
||||
assert len(new_shape_chunk) > 0
|
||||
new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
|
||||
new_reduce_axis.append(len(new_shape)-1)
|
||||
else:
|
||||
# otherwise, pass through the new_shape_chunk
|
||||
new_shape += new_shape_chunk
|
||||
ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
|
||||
assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
|
||||
return ret
|
||||
return None
|
||||
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view before elementwise and buffer ops
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
|
||||
# if there's ones added after reduce, put this before the reduce
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
|
||||
])
|
||||
|
||||
view_left_through_load = PatternMatcher([
|
||||
# view before load
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
|
||||
])
|
||||
|
||||
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
|
||||
|
||||
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
|
||||
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
# contiguous and same size can push to children
|
||||
# if there's a reduce child, shapes match with ones removed
|
||||
if unwrap(view.st).contiguous and view.size == r.size and \
|
||||
(not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
|
||||
tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
|
||||
return None
|
||||
# swizzle the input
|
||||
input_st = ShapeTracker.from_shape(src.shape)
|
||||
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
|
||||
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
|
||||
strides = strides_for_shape(rshape)
|
||||
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
|
||||
new_view = tmp + ShapeTracker(tuple(nv))
|
||||
swizzled_input = apply_swizzle(src.view(new_view))
|
||||
# create a new reduceop
|
||||
new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
|
||||
if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
|
||||
else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
|
||||
return red.reshape(view.shape)
|
||||
|
||||
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
|
||||
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
|
||||
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
|
||||
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
|
||||
|
||||
def elementwise_view_right(root:UOp):
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
|
||||
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
|
||||
# place view after applying the elementwise op
|
||||
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
|
||||
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
|
||||
# reshape to match downstream shapes
|
||||
return root.replace(src=tuple(new_src)).reshape(root.shape)
|
||||
|
||||
# push VIEW to children
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
# apply view after reduceops
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.STORE}, name="root"), elementwise_view_right),
|
||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||
])
|
||||
|
||||
def check_load_st(glbl:UOp, view:UOp):
|
||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
||||
# otherwise, it's not fine
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
|
||||
fix_kernel_ops = view_left_through_load+PatternMatcher([
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
# passthrough ASSIGN
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||
# VALID
|
||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
# no ImageDType after index
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||
])
|
||||
|
|
@ -66,7 +66,7 @@ class NVPageTableEntry:
|
|||
return self.read_fields(entry_id)[f'address{small}{sys}'] << 12
|
||||
|
||||
class NVMemoryManager(MemoryManager):
|
||||
va_allocator = TLSFAllocator((1 << 44), base=1 << 30) # global for all devices.
|
||||
va_allocator = TLSFAllocator((1 << 44), base=0x1000000000) # global for all devices.
|
||||
|
||||
def on_range_mapped(self): self.dev.NV_VIRTUAL_FUNCTION_PRIV_MMU_INVALIDATE.write((1 << 0) | (1 << 1) | (1 << 6) | (1 << 31))
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
|
|||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, Ops.LOAD}
|
||||
|
||||
# **** Grouper decides which of the UOps realize
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, sint
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
from tinygrad.opt.swizzler import merge_views, view_left, view_right, fix_kernel_ops, apply_swizzle, swizzle_reduceop
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
|
|
@ -148,139 +147,13 @@ create_kernels = PatternMatcher([
|
|||
lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
|
||||
])
|
||||
|
||||
# **** swizzler
|
||||
|
||||
merge_views = PatternMatcher([
|
||||
# merge adjacent views
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# replace MovementOps with VIEW
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
|
||||
# remove NOOP views
|
||||
(UPat.var("x").view(name="view"),
|
||||
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# only unmaksed VIEW on CONST replaces the ShapeTracker
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
])
|
||||
|
||||
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
|
||||
# contiguous, expand, and the same with ones removed
|
||||
if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
|
||||
tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
|
||||
new_shape: list[sint] = []
|
||||
new_reduce_axis = []
|
||||
if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
|
||||
for i,pairs in enumerate(contraction):
|
||||
new_shape_chunk = [view.shape[p] for p in pairs]
|
||||
if i in r.arg[1]:
|
||||
# if this is a reduce axis, we need a 1 in the view here to put it
|
||||
assert len(new_shape_chunk) > 0
|
||||
new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
|
||||
new_reduce_axis.append(len(new_shape)-1)
|
||||
else:
|
||||
# otherwise, pass through the new_shape_chunk
|
||||
new_shape += new_shape_chunk
|
||||
ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
|
||||
assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
|
||||
return ret
|
||||
return None
|
||||
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view before elementwise and buffer ops
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
|
||||
# if there's ones added after reduce, put this before the reduce
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
|
||||
])
|
||||
|
||||
def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
|
||||
|
||||
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
|
||||
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
|
||||
# contiguous and same size can push to children
|
||||
# if there's a reduce child, shapes match with ones removed
|
||||
if unwrap(view.st).contiguous and view.size == r.size and \
|
||||
(not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
|
||||
tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
|
||||
return None
|
||||
# swizzle the input
|
||||
input_st = ShapeTracker.from_shape(src.shape)
|
||||
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
|
||||
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
|
||||
strides = strides_for_shape(rshape)
|
||||
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
|
||||
new_view = tmp + ShapeTracker(tuple(nv))
|
||||
swizzled_input = apply_swizzle(src.view(new_view))
|
||||
# create a new reduceop
|
||||
new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
|
||||
if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
|
||||
else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
|
||||
return red.reshape(view.shape)
|
||||
|
||||
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
|
||||
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
|
||||
new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
|
||||
return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
|
||||
|
||||
def elementwise_view_right(root:UOp):
|
||||
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
|
||||
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
|
||||
# place view after applying the elementwise op
|
||||
new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
|
||||
new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
|
||||
# reshape to match downstream shapes
|
||||
return root.replace(src=tuple(new_src)).reshape(root.shape)
|
||||
|
||||
# push VIEW to children
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
|
||||
# apply view after reduceops
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
|
||||
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
|
||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
|
||||
])
|
||||
|
||||
# **** fix kernel AST
|
||||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
early_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st).load()),
|
||||
# no SINK for meta ops
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
# passthrough ASSIGN
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||
# VALID
|
||||
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
|
||||
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
|
||||
])
|
||||
|
||||
def check_load_st(glbl:UOp, view:UOp):
|
||||
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
|
||||
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
|
||||
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
|
||||
# if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
|
||||
# otherwise, it's not fine
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
|
||||
fix_kernel_ops = PatternMatcher([
|
||||
# remove CONTIGUOUS/DEVICE from kernel AST
|
||||
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
|
||||
# no ImageDType after index
|
||||
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
|
||||
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
|
||||
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
|
||||
])
|
||||
|
||||
replace_globals = PatternMatcher([
|
||||
|
|
@ -292,10 +165,6 @@ replace_globals = PatternMatcher([
|
|||
|
||||
def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
|
||||
# replace global memory ops with the BUFFER they write to
|
||||
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
||||
# push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
||||
# replace buffer with define_global + add load/store last
|
||||
bufs = []
|
||||
for s in k.src:
|
||||
|
|
@ -303,9 +172,15 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
|||
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
||||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||
bufs.append(s)
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
||||
# replace global memory ops with the BUFFER they write to
|
||||
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
||||
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
|
||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||
# TODO: move these to codegen
|
||||
ast = graph_rewrite(ast, view_left, name="Main View Left")
|
||||
ast = graph_rewrite(ast, view_right, name="Main View Right")
|
||||
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="replace buffer")
|
||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||
|
||||
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
|
||||
|
|
@ -441,8 +316,6 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
|
||||
|
||||
# TODO: move view_left/view_right here
|
||||
|
||||
# group into kernels (this is context-free)
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")
|
||||
|
||||
|
|
|
|||
|
|
@ -2234,7 +2234,7 @@ class Tensor(MathTrait):
|
|||
"""
|
||||
def parse_formula(formula:str, *operands:Tensor):
|
||||
if "..." in (formula := formula.replace(" ", "")):
|
||||
ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
|
||||
ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
|
||||
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
|
||||
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
|
||||
inputs[i] = inp.replace("...", ell_chars[-ell_count:])
|
||||
|
|
|
|||
|
|
@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
|
||||
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
||||
def variables(self) -> list[Variable]:
|
||||
st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
|
||||
st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
|
||||
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
|
|
|||
|
|
@ -132,7 +132,8 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
|
|||
name, cat, info = e.name, None, None
|
||||
if (ref:=ref_map.get(name)) is not None:
|
||||
name = ctxs[ref]["name"]
|
||||
if isinstance(p:=contexts[0][ref].ret, ProgramSpec):
|
||||
# TODO: support symbolic by capturing var_vals in profile events
|
||||
if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and all(isinstance(es,int) for es in [p.estimates.ops, p.estimates.mem, p.estimates.lds]):
|
||||
info = f"{p.estimates.ops/(t:=dur*1e3):.2f} GFLOPS {p.estimates.mem/t:4.1f}|{p.estimates.lds/t:.1f} GB/s"
|
||||
elif isinstance(e.name, TracingKey):
|
||||
name, cat = e.name.display_name, e.name.cat
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue