Merge branch 'master' into single_file_onnx/1

This commit is contained in:
zibokapi 2025-08-06 11:51:26 +08:00
commit 8b59601ca6
16 changed files with 386 additions and 328 deletions

View file

@ -870,28 +870,28 @@ jobs:
run: WEBGPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_runner.py
osxremote:
name: MacOS (remote metal)
runs-on: macos-15
timeout-minutes: 10
env:
REMOTE: 1
REMOTEDEV: METAL
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: macos-remote
deps: testing_minimal
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
name: MacOS (remote metal)
runs-on: macos-15
timeout-minutes: 10
env:
REMOTE: 1
REMOTEDEV: METAL
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: macos-remote
deps: testing_minimal
- name: Check Device.DEFAULT and print some source
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
python -c "from tinygrad import Device; assert Device.default.properties.real_device == 'METAL', Device.default.properties.real_device"
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
- name: Run REMOTE=1 Test
run: |
python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_tensor_variable.py
amdremote:
name: Linux (remote)

View file

@ -9,7 +9,7 @@ with open(directory / 'README.md', encoding='utf-8') as f:
testing_minimal = [
"numpy",
"torch",
"torch==2.7.1",
"pytest",
"pytest-xdist",
"hypothesis",

View file

@ -5,10 +5,10 @@ import numpy as np
from hypothesis import given, settings, strategies as strat
from test.helpers import assert_jit_cache_len, not_support_multi_device, REAL_DEV
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit, GraphRunner
from tinygrad.engine.jit import TinyJit, GraphRunner, MultiGraphRunner, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer
from tinygrad.device import Device
from tinygrad.helpers import Context, JIT, GlobalCounters
from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.helpers import Context, JIT, GlobalCounters, getenv
from tinygrad.dtype import dtypes
from extra.models.unet import ResBlock
@ -472,32 +472,6 @@ class TestJit(unittest.TestCase):
np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)
@unittest.skipUnless((not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None, "must be non-hcq with graph")
def test_jit_several_incompatible_devs(self):
  """Jit kernels on the default (non-HCQ, graph-capable) device and on CPU in one function.

  After the jit has been captured, every item in the jit cache is expected to be a GraphRunner.
  """
  assert isinstance(Device["CPU"], HCQCompiled) and Device["CPU"].graph is not None
  assert (not isinstance(Device.default, HCQCompiled)) and Device.default.graph is not None

  d0, d1 = Device.DEFAULT, "CPU"

  @TinyJit
  def f(a0, b0):
    # operate on the a0 parameter; the previous code read the enclosing loop variable `a` through the closure
    a1 = (a0 + 2.0).contiguous().realize()
    a2 = (a1 * 2.0).contiguous().realize()
    b1 = (b0 + 2.0).contiguous().realize()
    b2 = (b1 * 2.0).contiguous().realize()
    return a2, b2

  for _ in range(5):
    a = Tensor.randn(10, 10, device=d0).realize()
    b = Tensor.randn(10, 10, device=d1).realize()
    a1, b1 = f(a, b)
    np.testing.assert_allclose(((a.numpy()+2.0)*2.0), a1.numpy(), atol=1e-4, rtol=1e-5)
    np.testing.assert_allclose(((b.numpy()+2.0)*2.0), b1.numpy(), atol=1e-4, rtol=1e-5)

  assert all(isinstance(ei.prg, GraphRunner) for ei in f.jit_cache), repr(f.jit_cache)
@unittest.skipIf(not_support_multi_device(), "no multi")
def test_jitted_view(self):
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
@ -696,5 +670,170 @@ class TestJitFree(unittest.TestCase):
out = fxn(Tensor([11,1,2,3,4]))
self.assertEqual(out.item(), 13600)
class TestJitGraphSplit(unittest.TestCase):
  """Checks how a TinyJit-captured function is split into graph runners and standalone items.

  Each test builds a jitted function from the `compute`/`copy` helpers and passes `expect` three
  expected jit-cache layouts, one per kind of graph class the default device may report.
  """

  def compute(self, device, inp):
    """Realize `inp + 1.0` as a contiguous tensor, after asserting that `inp` lives on `device`."""
    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
    return (inp + 1.0).contiguous().realize()

  def copy(self, device, to_device, inp):
    """Realize a copy of `inp` on `to_device`, after asserting that `inp` lives on `device`."""
    assert inp.device == device, f"Input device {inp.device} does not match expected {device}"
    return inp.to(to_device).realize()

  def expect(self, f, *args, graph=None, multigraph=None, hcqgraph=None):
    """Run `f` five times, compare the outputs, then compare `f.jit_cache` with the expected layout.

    `graph`, `multigraph` and `hcqgraph` are lists of dicts built by the `ji_*` helpers. Which list is
    checked depends on `graph_class(Device[Device.DEFAULT])`: HCQGraph selects `hcqgraph`, a
    MultiGraphRunner subclass selects `multigraph`, anything else selects `graph`. When the device
    reports no graph class, only the numerical comparison is performed.
    """
    def _numpies(tpl): return tpl.numpy() if isinstance(tpl, Tensor) else tuple([t.numpy() for t in tpl])

    # the first call provides the reference output, the following four calls must reproduce it
    reference = _numpies(f(*args))
    for _ in range(4):
      res = _numpies(f(*args))
      np.testing.assert_allclose(res, reference, atol=1e-4, rtol=1e-5)

    dev = Device[Device.DEFAULT]
    graph_t = graph_class(dev)
    if graph_t is None: return

    jit_items = f.jit_cache
    from tinygrad.runtime.graph.hcq import HCQGraph
    if graph_t is HCQGraph:
      validate = hcqgraph
    elif issubclass(graph_t, MultiGraphRunner):
      validate = multigraph
    else:
      validate = graph

    # fail with a readable message instead of a TypeError from len(None)
    assert validate is not None, f"no expected layout was provided for graph class {graph_t}"
    assert len(jit_items) == len(validate), f"Expected {len(validate)} operations, got {len(jit_items)}"

    runner_types = {"comp": CompiledRunner, "copy": BufferCopy, "xfer": BufferXfer}
    # distinct loop names: the previous loop rebound `expected` and `got` while iterating over `got`
    for want, ji in zip(validate, jit_items):
      kind = want["type"]
      if kind == "graph":
        assert isinstance(ji.prg, GraphRunner), f"Expected GraphRunner, got {type(ji.prg)}"
        assert len(ji.prg.jit_cache) == want["cnt"], f"Expected {want['cnt']} operations in graph, got {len(ji.prg.jit_cache)}"
      elif kind in runner_types:
        runner_t = runner_types[kind]
        assert isinstance(ji.prg, runner_t), f"Expected {runner_t.__name__}, got {type(ji.prg)}"
      else:
        # an unrecognised expectation used to pass silently
        raise ValueError(f"unknown expectation type {kind!r}")

  # builders for the expected-layout entries consumed by expect()
  def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
  def ji_comp(self): return {"type": "comp"}
  def ji_copy(self): return {"type": "copy"}
  def ji_xfer(self): return {"type": "xfer"}

  def test_jit_split_simple(self):
    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")

    @TinyJit
    def f(inp):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(Device.DEFAULT, op1)
      return op2

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    self.expect(f, inp,
      graph=[self.ji_graph(3)],
      multigraph=[self.ji_graph(3)],
      hcqgraph=[self.ji_graph(3)])

  def test_jit_cpu_simple(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    @TinyJit
    def f(inp, inp_cpu):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute("CPU", inp_cpu)
      op3 = self.compute(Device.DEFAULT, op1)
      return op2, op3

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
    self.expect(f, inp, inp_cpu,
      graph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
      hcqgraph=[self.ji_graph(4)])

  def test_jit_cpu_several(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    @TinyJit
    def f(inp, inp_cpu):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute("CPU", inp_cpu)
      op3 = self.compute("CPU", op2)
      op4 = self.compute(Device.DEFAULT, op1)
      return op3, op4

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
    self.expect(f, inp, inp_cpu,
      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      hcqgraph=[self.ji_graph(5)])

  def test_jit_multidev(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    # opening the second device is the probe for multi-device support
    try: Device[f"{Device.DEFAULT}:1"]
    except Exception: raise unittest.SkipTest("no multidevice")

    @TinyJit
    def f(inp, inp_d1):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
      op3 = self.compute(f"{Device.DEFAULT}:1", op2)
      op4 = self.compute(Device.DEFAULT, op1)
      return op3, op4

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
    self.expect(f, inp, inp_d1,
      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      multigraph=[self.ji_graph(5)],
      hcqgraph=[self.ji_graph(5)])

  @unittest.skip("flaky")
  def test_jit_multidev_xfer(self):
    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")

    # opening the second device is the probe for multi-device support
    try: Device[f"{Device.DEFAULT}:1"]
    except Exception: raise unittest.SkipTest("no multidevice")

    @TinyJit
    def f(inp, inp_d1):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
      op3 = self.copy(f"{Device.DEFAULT}:1", Device.DEFAULT, op2)
      op4 = self.compute(f"{Device.DEFAULT}:1", op2)
      op5 = self.compute(Device.DEFAULT, op3)
      return op1, op4, op5

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
    self.expect(f, inp, inp_d1,
      graph=[self.ji_graph(2), self.ji_comp(), self.ji_xfer(), self.ji_comp(), self.ji_comp()],
      multigraph=[self.ji_graph(6)],
      hcqgraph=[self.ji_graph(6)])

  @unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies")
  def test_jit_multidev_copy(self):
    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
    if Device.DEFAULT == "REMOTE": raise unittest.SkipTest("REMOTE gpu is broken")

    @TinyJit
    def f(inp):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.copy(Device.DEFAULT, "CPU", op1)
      op3 = self.compute("CPU", op2)
      return op3

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    self.expect(f, inp,
      graph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
      hcqgraph=[self.ji_graph(4)])
if __name__ == '__main__':
unittest.main()

View file

@ -108,105 +108,39 @@ class TestNN(unittest.TestCase):
_test_linear(Tensor.randn(BS, in_dim), in_dim, out_dim)
_test_linear(Tensor.randn(BS, T, in_dim), in_dim, out_dim) # test with more dims
def test_conv1d(self):
BS, C1, W = 4, 16, 224//4
C2, K, S, P = 64, 7, 2, 1
def _test_conv(self, tiny_conv, torch_conv, BS, C1, DIMS, C2, K, S, P, D=1):
# create in tinygrad
layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
layer = tiny_conv(C1, C2, kernel_size=K, stride=S, padding=P, dilation=D)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
torch_layer = torch_conv(C1, C2, kernel_size=K, stride=S, padding=P, dilation=D).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.uniform(BS, C1, W)
x = Tensor.uniform(BS, C1, *DIMS)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
def test_conv2d(self):
BS, C1, H, W = 4, 16, 224//4, 224//4
C2, K, S, P = 64, 7, 2, 1
# create in tinygrad
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.uniform(BS, C1, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
def test_conv1d(self): self._test_conv(Conv1d, torch.nn.Conv1d, BS=4, C1=16, DIMS=[224//4], C2=64, K=7, S=2, P=1)
def test_conv2d(self): self._test_conv(Conv2d, torch.nn.Conv2d, BS=4, C1=16, DIMS=[224//4, 224//4], C2=64, K=7, S=2, P=1)
def test_conv1d_same_padding(self):
BS, C1, W = 8, 3, 32
C2, K, S, P = 16, 3, 1, 'same'
# create in tinygrad
layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.uniform(BS, C1, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
def _run_conv2d_same_padding_test(self, BS, C1, C2, H, W, K, S, padding='same', D=1):
# create in tinygrad
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.uniform(BS, C1, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
self._test_conv(Conv1d, torch.nn.Conv1d, BS=8, C1=3, DIMS=[32], C2=16, K=3, S=1, P='same')
def test_conv2d_same_padding_odd_input(self):
BS, C1, H, W = 16, 16, 29, 31
C2, K, S, P = 32, 5, 1, 'same'
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=16, DIMS=[29, 31], C2=32, K=5, S=1, P='same')
def test_conv2d_same_padding_large_kernel(self):
BS, C1, H, W = 16, 16, 28, 33
C2, K, S, P = 32, 9, 1, 'same'
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=16, DIMS=[28, 33], C2=32, K=9, S=1, P='same')
def test_conv2d_same_padding_with_dilation(self):
BS, C1, H, W = 16, 3, 28, 28
C2, K, S, P, D = 32, 3, 1, 'same', 3
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P, D)
self._test_conv(Conv2d, torch.nn.Conv2d, BS=16, C1=3, DIMS=[28, 28], C2=32, K=3, S=1, P='same', D=3)
def test_conv2d_same_padding_invalid_stride(self):
C1, C2, K, S, P = 16, 32, 2, 2, 'same'
self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
self.assertRaises(ValueError, Conv2d, in_channels=16, out_channels=32, kernel_size=2, stride=2, padding='same')
def test_conv2d_same_padding_invalid_padding_str(self):
C1, C2, K, S, P = 16, 32, 2, 1, 'not_same'
self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
self.assertRaises(ValueError, Conv2d, in_channels=16, out_channels=32, kernel_size=2, stride=1, padding='not_same')
@unittest.skip("Takes too long to compile for Compiled backends")
def test_conv2d_winograd(self):
@ -229,12 +163,13 @@ class TestNN(unittest.TestCase):
with Context(WINO=1):
z = layer(x)
m = z.mean()
m.backward()
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
m = z.mean()
m.backward()
gw = layer.weight.grad.realize()
gb = layer.bias.grad.realize()
gx = x.grad.realize()
@ -245,44 +180,9 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5)
def test_conv_transpose1d(self):
BS, C1, W = 4, 16, 224//4
C2, K, S, P = 64, 7, 2, 1
# create in tinygrad
layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.uniform(BS, C1, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
self._test_conv(ConvTranspose1d, torch.nn.ConvTranspose1d, BS=4, C1=16, DIMS=[224//4], C2=64, K=7, S=2, P=1)
def test_conv_transpose2d(self):
BS, C1, H, W = 4, 16, 224//4, 224//4
C2, K, S, P = 64, 7, 2, 1
# create in tinygrad
layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.uniform(BS, C1, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
self._test_conv(ConvTranspose2d, torch.nn.ConvTranspose2d, BS=4, C1=16, DIMS=[224//4, 224//4], C2=64, K=7, S=2, P=1)
def test_groupnorm(self):
BS, H, W, C, G = 20, 10, 10, 6, 3

View file

@ -892,13 +892,13 @@ class TestIdxUpcast(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
def test_overflow_sym(self):
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
def test_regular(self):
self.do_op_then_assert(dtypes.int, 64, 64, 64)
def test_regular_sym(self):
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 0, 64).bind(32))
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32))
@unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64")
def test_symfold(self):
@ -910,7 +910,7 @@ class TestIdxUpcast(unittest.TestCase):
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
def test_int64_unsupported_overflow_sym(self):
with self.assertRaises(KeyError):
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
def test_int64_unsupported_overflow(self):

View file

@ -16,6 +16,7 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
ReduceContext, correct_load_store, pm_render
from tinygrad.codegen.optional import get_late_rewrite_patterns
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
from tinygrad.opt import pm_optimize
@dataclass
class RewriteStep:
@ -42,6 +43,10 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
# ** lowerer (rewrite_shapetracker_with_index) **
ret: list[RewriteStep] = []
# this is kernel.py
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))

View file

@ -7,9 +7,7 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.opt import get_optimized_ast
from tinygrad.codegen import full_rewrite
from tinygrad.uop.spec import type_verify
# **************** Program Creation ****************
@ -27,16 +25,13 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
"""
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
if __debug__: type_verify(list(modified_ast.toposort()))
# linearize
try:
uops = full_rewrite(modified_ast, renderer)
uops = full_rewrite(ast, renderer)
except RuntimeError:
print("***** LINEARIZE FAILURE *****")
print(f"ast = {ast}")
print(f"opts = {modified_ast.arg.applied_opts}")
raise
assert uops[-1].op is Ops.SINK, "last uop must be sink"

View file

@ -2,9 +2,10 @@
from tinygrad.opt.kernel import Kernel
from tinygrad.opt.heuristic import hand_coded_optimizations
from tinygrad.uop.ops import UOp
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
from tinygrad.renderer import Renderer
from tinygrad.uop.spec import type_verify
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
"""
@ -27,4 +28,11 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
kb = Kernel(ast, opts=renderer)
rawbufs = bufs_from_lin(kb, allocate=False)
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
return k.get_optimized_ast()
ret = k.get_optimized_ast()
if __debug__: type_verify(list(ret.toposort()))
return ret
# Rewrite rule that applies get_optimized_ast to a SINK node; ctx is passed as the renderer argument.
# The rule only fires when the SINK has no arg or its arg has opts_to_apply set, and the first source
# has a non-None `st`; otherwise the lambda returns None and the node is left unchanged.
pm_optimize = PatternMatcher([
(UPat(Ops.SINK, name="ast"), lambda ctx,ast:
get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
])

View file

@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction
from tinygrad.schedule.kernelize import view_left
from tinygrad.opt.swizzler import view_left, view_left_through_load
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
@ -503,4 +503,4 @@ class Kernel:
self.finalized = True
fixed_ast = fixup_ast(self.ast)
del fixup_ast
return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")
return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")

137
tinygrad/opt/swizzler.py Normal file
View file

@ -0,0 +1,137 @@
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
from tinygrad.helpers import all_same, prod, unwrap, colored
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
from tinygrad.dtype import ImageDType, dtypes
# Rules that collapse and simplify VIEW nodes; this matcher is the prefix of both view_left and view_right.
merge_views = PatternMatcher([
# merge adjacent views
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
# replace MovementOps with VIEW
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
# remove NOOP views
(UPat.var("x").view(name="view"),
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
# a view whose last mask has a zero-width range is replaced by a constant 0 (never applied to DEFINE_GLOBAL)
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
# only an unmasked VIEW on CONST/DEFINE_VAR is folded into the ShapeTracker arg of the node's first source
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
])
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
  """Move a VIEW that only adds size-1 dims after a REDUCE_AXIS to before the reduce.

  Returns a replacement for `r` whose source is `src` reshaped to carry the extra ones and whose
  reduce axes are renumbered accordingly, or None when the view does not qualify.
  """
  def _drop_ones(shape): return tuple(x for x in shape if resolve(x != 1))

  # the view must be contiguous, have more dims than the reduce output, and match it once ones are removed
  if not (unwrap(view.st).contiguous and len(r.shape) < len(view.shape)): return None
  if not (_drop_ones(r.shape) == _drop_ones(view.shape)): return None

  reshaped: list[sint] = []
  reduce_axes = []
  contraction = get_contraction_with_reduce(view.shape, r.shape, r.arg[1])
  if contraction is None: return None

  for idx, group in enumerate(contraction):
    view_dims = [view.shape[p] for p in group]
    if idx not in r.arg[1]:
      # non-reduce axis: the view's dims pass through unchanged
      reshaped.extend(view_dims)
      continue
    # reduce axis: pad with ones and keep the source's full extent as the last dim of the group
    assert len(view_dims) > 0
    reshaped.extend([1]*(len(group)-1) + [src.shape[idx]])
    reduce_axes.append(len(reshaped)-1)

  pushed = r.replace(src=(src.reshape(tuple(reshaped)),), arg=(r.arg[0], tuple(reduce_axes))+r.arg[2:])
  assert pushed.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {pushed.shape} != {view.shape}"
  return pushed
# merge_views plus rules that move a VIEW from the output of an op onto that op's sources.
view_left = merge_views+PatternMatcher([
# view before elementwise and buffer ops
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
# if there's ones added after reduce, put this before the reduce
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
])
# Separate matcher (not part of view_left) that applies a VIEW on a LOAD to each of the LOAD's sources instead.
view_left_through_load = PatternMatcher([
# view before load
(UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
])
def apply_swizzle(u:UOp) -> UOp:
  """Rewrite `u` with the view_left rules and return the result."""
  rewritten = graph_rewrite(u, view_left, name="Sub View Left")
  return rewritten
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
  """Push `view` (a VIEW on top of the REDUCE_AXIS `r`) into the reduce's input `src`.

  Returns None when the view is contiguous, has the same size as `r`, and either `r` has no fuse
  marker or the shapes agree position-by-position once ones are ignored. Otherwise returns a new
  REDUCE_AXIS over the swizzled input, reshaped to `view.shape`.
  """
  # contiguous and same size can push to children
  if unwrap(view.st).contiguous and view.size == r.size:
    has_fuse_marker = len(r.arg) == 3 and r.arg[2]  # arg[2] = True is fuse marker
    # if there's a reduce child, shapes match with ones removed
    if not has_fuse_marker or \
       tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1)):
      return None

  # swizzle the input: move the reduce axes to the end of the source shape
  src_st = ShapeTracker.from_shape(src.shape)
  kept_axes = tuple(i for i in range(len(src_st.shape)) if i not in r.axis_arg)
  permuted = src_st.permute(kept_axes+r.axis_arg)
  reduce_shape = permuted.shape[-len(r.axis_arg):]
  reduce_size = prod(reduce_shape)
  reduce_strides = strides_for_shape(reduce_shape)

  # extend every view of the output ShapeTracker with the reduce dims
  extended_views = []
  for v in unwrap(view.st).views:
    ext_mask = v.mask+tuple((0,s) for s in reduce_shape) if v.mask is not None else None
    extended_views.append(View.create(v.shape+reduce_shape, tuple(x*reduce_size for x in v.strides)+reduce_strides,
                                      v.offset*reduce_size, ext_mask))
  swizzled_input = apply_swizzle(src.view(permuted + ShapeTracker(tuple(extended_views))))

  # create a new reduceop over the trailing axes
  first_reduce_axis = len(view.shape)
  new_axis = tuple(range(first_reduce_axis, first_reduce_axis + len(r.axis_arg)))
  if fuse:
    red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
  else:
    red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
  return red.reshape(view.shape)
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
  """Reduce `src` directly over the axes where its shape differs from `r.shape`, then reshape to `r.shape`.

  Asserts that the view `v` between them is contiguous and has the same size as `src`.
  """
  assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
  # the reduce axes are the positions where the source dim and the reduced dim disagree
  changed_axes = tuple(axis for axis,(before,after) in enumerate(zip(src.shape, r.shape)) if before != after)
  return src.r(r.arg[0], changed_axes).reshape(r.shape)
def elementwise_view_right(root:UOp):
  """Apply `root` to the un-viewed bases of its sources and put a reshape to `root.shape` on the result.

  Returns None when no source is a VIEW whose base op is outside ALWAYS_CONTIGUOUS.
  """
  swizzles = [x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]
  if not swizzles: return None
  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"

  # place view after applying the elementwise op
  base_st = ShapeTracker.from_shape(swizzles[0].base.shape)
  new_src = []
  for s in root.src:
    if s.base.shape == base_st.shape:
      new_src.append(s.base)
    else:
      # sources with a different base shape are viewed to the common shape instead
      new_src.append(apply_swizzle(s.view(base_st)))

  # reshape to match downstream shapes
  return root.replace(src=tuple(new_src)).reshape(root.shape)
# push VIEW to children
# merge_views plus rules that move a VIEW from an op's sources to after the op.
view_right = merge_views+PatternMatcher([
# push a non contiguous ShapeTracker through reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.STORE}, name="root"), elementwise_view_right),
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
# only merged when both reduces use the same op (r1.arg[0] is r2.arg[0])
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
])
def check_load_st(glbl:UOp, view:UOp):
  """Raise RuntimeError when buffer 0 is loaded through a view this check cannot accept.

  Returns None without raising when `glbl.arg` is not 0, when the view is contiguous, or when the
  view has a single View that passes one of the two shrink checks below.
  """
  if glbl.arg != 0: return
  st = unwrap(view.st)
  if st.contiguous: return

  if len(st.views) == 1:
    # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
    unexpanded = tuple((0,1) if stride == 0 else (0,dim) for dim,stride in zip(st.shape, st.views[0].strides))
    if st.shrink(unexpanded).contiguous: return
    # if it has a single view and it's equal when you shrink a contig, it's fine
    mask = st.views[0].mask
    if mask is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return

  # otherwise, it's not fine
  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                     +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
# view_left_through_load plus rules that rewrite schedule-level ops inside a kernel AST.
fix_kernel_ops = view_left_through_load+PatternMatcher([
# STORE (except for meta ops)
# a SINK whose sources are not STOREs gets one STORE per source, into DEFINE_GLOBAL number i viewed with the source base's st
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
# VALID
# a VIEW of a constant becomes where(VALID(view), const, 0)
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
# remove CONTIGUOUS/MSELECT wrappers and drop the DEVICE source of a VIEW from the kernel AST
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# no ImageDType after index
# applies to every op except DEFINE_GLOBAL and VIEW: an ImageDType is replaced by its base dtype
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])

View file

@ -66,7 +66,7 @@ class NVPageTableEntry:
return self.read_fields(entry_id)[f'address{small}{sys}'] << 12
class NVMemoryManager(MemoryManager):
va_allocator = TLSFAllocator((1 << 44), base=1 << 30) # global for all devices.
va_allocator = TLSFAllocator((1 << 44), base=0x1000000000) # global for all devices.
def on_range_mapped(self): self.dev.NV_VIRTUAL_FUNCTION_PRIV_MMU_INVALIDATE.write((1 << 0) | (1 << 1) | (1 << 6) | (1 << 31))

View file

@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
from tinygrad.shape.shapetracker import ShapeTracker
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, Ops.LOAD}
# **** Grouper decides which of the UOps realize

View file

@ -1,14 +1,13 @@
from dataclasses import dataclass
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve, sint
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.opt.swizzler import merge_views, view_left, view_right, fix_kernel_ops, apply_swizzle, swizzle_reduceop
# creation can recurse a lot
import sys
@ -148,139 +147,13 @@ create_kernels = PatternMatcher([
lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
])
# **** swizzler
# Rewrite rules that fold movement ops and stacked VIEWs together, and drop VIEWs that change nothing.
# Each entry is (pattern, rewrite); a rewrite returning None leaves the matched node untouched.
merge_views = PatternMatcher([
# merge adjacent views (the inner VIEW is kept, its arg becomes v1.arg+v2.arg)
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
# replace MovementOps with VIEW (a view of x.base using the movement op's st)
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
# remove NOOP views: the view is contiguous and has the same shape as its source
# (skipped when x has no st or when x.op is in GroupOp.Defines)
(UPat.var("x").view(name="view"),
lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
# a view whose last mask contains an empty range (end-start == 0) is replaced with view.const_like(0)
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
# only unmasked VIEW on CONST replaces the ShapeTracker
# (applies only if no view of the combined tracker x.st+view.st carries a mask)
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
])
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
  """Move a VIEW that only inserts size-1 axes after a reduce to before the reduce.

  `r` is the REDUCE_AXIS whose input is `src`, and `view` is the VIEW applied on top of `r`.
  Returns a replacement for `view` (a copy of `r` over a reshaped `src`, with the reduce axes
  renumbered), or None when the rewrite does not apply.
  """
  # contiguous, expand, and the same with ones removed
  if not unwrap(view.st).contiguous: return None
  if len(r.shape) >= len(view.shape): return None
  reduce_dims = tuple(d for d in r.shape if resolve(d != 1))
  view_dims = tuple(d for d in view.shape if resolve(d != 1))
  if not (reduce_dims == view_dims): return None

  contraction = get_contraction_with_reduce(view.shape, r.shape, r.arg[1])
  if contraction is None: return None

  out_shape: list[sint] = []
  out_axes = []
  for idx,group in enumerate(contraction):
    if idx not in r.arg[1]:
      # not a reduce axis: the view's own dims pass straight through
      out_shape.extend(view.shape[p] for p in group)
      continue
    # if this is a reduce axis, we need a 1 in the view here to put it
    assert len(group) > 0
    out_shape.extend([1]*(len(group)-1))
    out_shape.append(src.shape[idx])
    out_axes.append(len(out_shape)-1)

  reshaped_src = src.reshape(tuple(out_shape))
  pushed = r.replace(src=(reshaped_src,), arg=(r.arg[0], tuple(out_axes))+r.arg[2:])
  assert pushed.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {pushed.shape} != {view.shape}"
  return pushed
# view_left: the merge_views rules plus rules that move a VIEW from an op's output onto that op's sources.
view_left = merge_views+PatternMatcher([
# view before elementwise and buffer ops
# (the op `e` is kept and every one of its sources is wrapped in the same view.st)
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
# if there's ones added after reduce, put this before the reduce
# (reduce_push_add_ones returns None when the view does more than add ones, leaving the node unchanged)
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
])
def apply_swizzle(u:UOp) -> UOp:
  """Rewrite `u` with the `view_left` rules and return the resulting graph."""
  rewritten = graph_rewrite(u, view_left, name="Sub View Left")
  return rewritten
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
  """Push `view` (a VIEW applied to REDUCE_AXIS `r`) onto the reduce's input `src`.

  Returns None when the view can stay where it is: it is contiguous, has the same size as `r`,
  and either `r` has no fuse marker or the non-1 dims of `r` and `view` match position by position.
  Otherwise returns a new REDUCE_AXIS over the swizzled input, reshaped to `view.shape`.
  With `fuse=True` the swizzled input is wrapped with `.fuse()` and the new reduce gets the fuse marker.
  """
  # contiguous and same size can push to children
  if unwrap(view.st).contiguous and view.size == r.size:
    has_fuse_marker = len(r.arg) == 3 and r.arg[2]  # arg[2] = True is fuse marker
    if not has_fuse_marker: return None
    # if there's a reduce child, shapes match with ones removed
    r_dims = tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1))
    view_dims = tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))
    if r_dims == view_dims: return None

  # swizzle the input: move the reduce axes to the end of the input shape
  base_st = ShapeTracker.from_shape(src.shape)
  reduce_axes = r.axis_arg
  kept_axes = tuple(i for i in range(len(base_st.shape)) if i not in reduce_axes)
  permuted = base_st.permute(kept_axes+reduce_axes)
  reduce_shape = permuted.shape[-len(reduce_axes):]
  reduce_numel = prod(reduce_shape)
  reduce_strides = strides_for_shape(reduce_shape)

  # extend every View of `view` with the reduce dims appended at the end
  extended = []
  for v in unwrap(view.st).views:
    ext_shape = v.shape+reduce_shape
    ext_strides = tuple(x*reduce_numel for x in v.strides)+reduce_strides
    ext_offset = v.offset*reduce_numel
    ext_mask = v.mask+tuple((0,s) for s in reduce_shape) if v.mask is not None else None
    extended.append(View.create(ext_shape, ext_strides, ext_offset, ext_mask))
  swizzled_input = apply_swizzle(src.view(permuted + ShapeTracker(tuple(extended))))

  # create a new reduceop over the trailing axes
  first_axis = len(view.shape)
  new_axis = tuple(range(first_axis, first_axis + len(reduce_axes)))
  if fuse:
    red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
  else:
    red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
  return red.reshape(view.shape)
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
  """Drop the VIEW `v` between `src` and the reduce `r`: reduce `src` directly, then reshape to `r.shape`.

  The reduced axes are the positions where `src.shape` and `r.shape` differ.
  Asserts that `v` is contiguous and has the same size as `src`.
  """
  same_layout = unwrap(v.st).contiguous and v.size == src.size
  assert same_layout, f"can't compute new axis for {src.shape} -> {r.shape}"
  changed_axes = tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)
  reduced = src.r(r.arg[0], changed_axes)
  return reduced.reshape(r.shape)
def elementwise_view_right(root:UOp):
  """Move VIEWs on the sources of `root` to after `root`.

  Only VIEW sources whose base op is not in ALWAYS_CONTIGUOUS count as swizzles; with none, returns None.
  Otherwise every source is brought to the shape of the first swizzle's base and the result
  is reshaped back to `root.shape`.
  """
  swizzles = [s for s in root.src if s.op is Ops.VIEW and s.base.op not in ALWAYS_CONTIGUOUS]
  if not swizzles: return None
  assert all_same([s.base.size for s in swizzles]), f"swizzle inputs must have the same size {swizzles}"
  # place view after applying the elementwise op
  target_st = ShapeTracker.from_shape(swizzles[0].base.shape)
  moved_src = []
  for s in root.src:
    if s.base.shape == target_st.shape: moved_src.append(s.base)
    else: moved_src.append(apply_swizzle(s.view(target_st)))
  # reshape to match downstream shapes
  unviewed = root.replace(src=tuple(moved_src))
  return unviewed.reshape(root.shape)
# push VIEW to children
# view_right: the merge_views rules plus rules that move a VIEW from an op's sources to after the op.
view_right = merge_views+PatternMatcher([
# push a non contiguous ShapeTracker through reduceop
# (swizzle_reduceop returns None, i.e. no rewrite, for views it decides can stay in place)
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
# apply view after reduceops (only when the viewed source op is not in ALWAYS_CONTIGUOUS)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops (every op except SINK and REDUCE_AXIS is tried)
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
# only when both reduces use the same op (r1.arg[0] is r2.arg[0]); the merged axes are r2.arg[1]+r1.arg[1]
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
])
# **** fix kernel AST
add_buffer_ops = PatternMatcher([
early_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
# STORE (except for meta ops)
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st).load()),
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
# passthrough ASSIGN
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
# VALID
(UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
])
def check_load_st(glbl:UOp, view:UOp):
  """Check that a LOAD through `view` of global `glbl` uses an access pattern that is allowed.

  Only globals with `arg == 0` are checked (presumably the buffer the kernel also assigns to; confirm with caller).
  Returns None when the view is acceptable and raises RuntimeError otherwise.
  """
  if glbl.arg != 0: return None
  tracker = unwrap(view.st)
  if tracker.contiguous: return None
  if len(tracker.views) == 1:
    only_view = tracker.views[0]
    # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
    unexpanded = tuple((0,1) if stride == 0 else (0,dim) for dim,stride in zip(tracker.shape, only_view.strides))
    if tracker.shrink(unexpanded).contiguous: return None
    # if it has a single view and it's equal when you shrink a contig, it's fine
    mask = only_view.mask
    if mask is not None and ShapeTracker.from_shape(tracker.shape).shrink(mask) == tracker.shrink(mask): return None
  # otherwise, it's not fine
  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                     +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# Clean-up rules applied to a kernel AST. Each entry is (pattern, rewrite); a rewrite returning None means no change.
fix_kernel_ops = PatternMatcher([
# remove CONTIGUOUS/DEVICE from kernel AST
# a single-source CONTIGUOUS or MSELECT is replaced by its source
(UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
# a VIEW whose only source is a DEVICE keeps its arg but loses the source
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
# no ImageDType after index
# (every op except DEFINE_GLOBAL and VIEW has an ImageDType replaced by its .base dtype)
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
# (check_load_st never returns a replacement; it only raises RuntimeError on an unsupported view)
(UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
])
replace_globals = PatternMatcher([
@ -292,10 +165,6 @@ replace_globals = PatternMatcher([
def fix_kernel_ast(k:UOp) -> UOp|None:
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
# replace global memory ops with the BUFFER they write to
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
# push views to edges
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
# replace buffer with define_global + add load/store last
bufs = []
for s in k.src:
@ -303,9 +172,15 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s)
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
# replace global memory ops with the BUFFER they write to
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
ast = graph_rewrite(ast, early_buffer_ops, bufs, bottom_up=True, name="replace buffer early")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
# TODO: move these to codegen
ast = graph_rewrite(ast, view_left, name="Main View Left")
ast = graph_rewrite(ast, view_right, name="Main View Right")
ast = graph_rewrite(ast, view_left+fix_kernel_ops, bottom_up=True, name="replace buffer")
return k.replace(arg=Kernel(ast, k.arg.metadata))
# Single-rule matcher: every KERNEL UOp is handed to fix_kernel_ast, which returns the replacement kernel or None.
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
@ -441,8 +316,6 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
# TODO: move view_left/view_right here
# group into kernels (this is context-free)
tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")

View file

@ -2234,7 +2234,7 @@ class Tensor(MathTrait):
"""
def parse_formula(formula:str, *operands:Tensor):
if "..." in (formula := formula.replace(" ", "")):
ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
inputs[i] = inp.replace("...", ell_chars[-ell_count:])

View file

@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> list[Variable]:
st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
# *** uop symbolic stuff ***

View file

@ -132,7 +132,8 @@ def timeline_layout(events:list[tuple[int, int, float, DevEvent]]) -> dict:
name, cat, info = e.name, None, None
if (ref:=ref_map.get(name)) is not None:
name = ctxs[ref]["name"]
if isinstance(p:=contexts[0][ref].ret, ProgramSpec):
# TODO: support symbolic by capturing var_vals in profile events
if isinstance(p:=contexts[0][ref].ret, ProgramSpec) and all(isinstance(es,int) for es in [p.estimates.ops, p.estimates.mem, p.estimates.lds]):
info = f"{p.estimates.ops/(t:=dur*1e3):.2f} GFLOPS {p.estimates.mem/t:4.1f}|{p.estimates.lds/t:.1f} GB/s"
elif isinstance(e.name, TracingKey):
name, cat = e.name.display_name, e.name.cat