mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
master
...
move_sched
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
150ee9eb6d | ||
|
|
5dcdfb0d75 | ||
|
|
bb69860d41 |
25 changed files with 206 additions and 206 deletions
|
|
@ -38,7 +38,7 @@ optim.schedule_step() # this will step the optimizer without running realize
|
||||||
# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
|
# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
|
||||||
# l1.uop and l2.uop define a computation graph
|
# l1.uop and l2.uop define a computation graph
|
||||||
|
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
schedule: List[ExecItem] = Tensor.schedule(l1, l2)
|
schedule: List[ExecItem] = Tensor.schedule(l1, l2)
|
||||||
|
|
||||||
print(f"The schedule contains {len(schedule)} items.")
|
print(f"The schedule contains {len(schedule)} items.")
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@ The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not al
|
||||||
|
|
||||||
## Scheduling
|
## Scheduling
|
||||||
|
|
||||||
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
|
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/schedule/__init__.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
|
||||||
|
|
||||||
::: tinygrad.engine.schedule.ExecItem
|
::: tinygrad.schedule.ExecItem
|
||||||
|
|
||||||
## Lowering
|
## Lowering
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from tinygrad import Device, nn, Tensor, dtypes
|
||||||
from train_gpt2 import GPT, GPTConfig
|
from train_gpt2 import GPT, GPTConfig
|
||||||
from tinygrad.helpers import DEV, dedup, flatten, getenv, GlobalCounters, to_function_name
|
from tinygrad.helpers import DEV, dedup, flatten, getenv, GlobalCounters, to_function_name
|
||||||
from tinygrad.engine.realize import get_kernel
|
from tinygrad.engine.realize import get_kernel
|
||||||
from tinygrad.engine.memory import memory_planner
|
from tinygrad.schedule.memory import memory_planner
|
||||||
from tinygrad.uop.ops import Ops
|
from tinygrad.uop.ops import Ops
|
||||||
|
|
||||||
DEV.value = "CPU"
|
DEV.value = "CPU"
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
|
||||||
from tinygrad.helpers import Context, getenv, DEV
|
from tinygrad.helpers import Context, getenv, DEV
|
||||||
from tinygrad.engine.realize import run_schedule
|
from tinygrad.engine.realize import run_schedule
|
||||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
from tinygrad.renderer import Estimates
|
from tinygrad.renderer import Estimates
|
||||||
from tinygrad.renderer.ptx import PTXRenderer
|
from tinygrad.renderer.ptx import PTXRenderer
|
||||||
from test.helpers import needs_second_gpu
|
from test.helpers import needs_second_gpu
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from tinygrad.helpers import Context, dedup, from_mv
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.engine.jit import MultiGraphRunner
|
from tinygrad.engine.jit import MultiGraphRunner
|
||||||
from tinygrad.engine.realize import BufferXfer, get_runner, CompiledRunner
|
from tinygrad.engine.realize import BufferXfer, get_runner, CompiledRunner
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
from tinygrad.uop.ops import UOp, Ops
|
from tinygrad.uop.ops import UOp, Ops
|
||||||
|
|
||||||
from test.helpers import needs_second_gpu
|
from test.helpers import needs_second_gpu
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from tinygrad import Tensor, Device
|
||||||
from tinygrad.helpers import get_single_element
|
from tinygrad.helpers import get_single_element
|
||||||
from tinygrad.codegen.opt import Opt, OptOps
|
from tinygrad.codegen.opt import Opt, OptOps
|
||||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
|
|
||||||
class TestOptGemm(unittest.TestCase):
|
class TestOptGemm(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from tinygrad import Tensor, Context, Device, dtypes
|
||||||
from tinygrad.uop.ops import Ops
|
from tinygrad.uop.ops import Ops
|
||||||
from tinygrad.codegen.opt import Opt, OptOps
|
from tinygrad.codegen.opt import Opt, OptOps
|
||||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
|
|
||||||
N = 512
|
N = 512
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Device
|
||||||
from tinygrad.uop.ops import Ops, UOp, KernelInfo, AxisType
|
from tinygrad.uop.ops import Ops, UOp, KernelInfo, AxisType
|
||||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||||
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
|
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
from tinygrad.device import is_dtype_supported
|
from tinygrad.device import is_dtype_supported
|
||||||
from tinygrad.codegen.opt import Opt, OptOps
|
from tinygrad.codegen.opt import Opt, OptOps
|
||||||
from tinygrad.renderer.ptx import PTXRenderer
|
from tinygrad.renderer.ptx import PTXRenderer
|
||||||
|
|
|
||||||
2
test/external/external_test_amd.py
vendored
2
test/external/external_test_amd.py
vendored
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Device, Tensor
|
from tinygrad import Device, Tensor
|
||||||
from tinygrad.engine.schedule import create_schedule
|
from tinygrad.schedule import create_schedule
|
||||||
from tinygrad.runtime.ops_amd import AMDDevice
|
from tinygrad.runtime.ops_amd import AMDDevice
|
||||||
|
|
||||||
class TestAMD(unittest.TestCase):
|
class TestAMD(unittest.TestCase):
|
||||||
|
|
|
||||||
2
test/external/external_test_hip_compile.py
vendored
2
test/external/external_test_hip_compile.py
vendored
|
|
@ -2,7 +2,7 @@ import time, unittest
|
||||||
from tinygrad.runtime.support.hip_comgr import compile_hip
|
from tinygrad.runtime.support.hip_comgr import compile_hip
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor
|
||||||
from tinygrad.device import Device
|
from tinygrad.device import Device
|
||||||
from tinygrad.engine.schedule import create_schedule
|
from tinygrad.schedule import create_schedule
|
||||||
from tinygrad.codegen.opt.kernel import Kernel
|
from tinygrad.codegen.opt.kernel import Kernel
|
||||||
|
|
||||||
class TestHIPCompileSpeed(unittest.TestCase):
|
class TestHIPCompileSpeed(unittest.TestCase):
|
||||||
|
|
|
||||||
2
test/external/external_test_opt.py
vendored
2
test/external/external_test_opt.py
vendored
|
|
@ -7,7 +7,7 @@ from tinygrad import GlobalCounters, Tensor, Device
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from tinygrad.engine.realize import capturing, run_schedule
|
from tinygrad.engine.realize import capturing, run_schedule
|
||||||
from tinygrad.engine.schedule import linear_to_schedule
|
from tinygrad.schedule import linear_to_schedule
|
||||||
from tinygrad.tensor import _to_np_dtype
|
from tinygrad.tensor import _to_np_dtype
|
||||||
|
|
||||||
class CLCache:
|
class CLCache:
|
||||||
|
|
|
||||||
2
test/external/external_uop_gc.py
vendored
2
test/external/external_uop_gc.py
vendored
|
|
@ -1,6 +1,6 @@
|
||||||
import gc
|
import gc
|
||||||
from tinygrad import Tensor, UOp, Device, nn
|
from tinygrad import Tensor, UOp, Device, nn
|
||||||
from tinygrad.engine.schedule import schedule_cache
|
from tinygrad.schedule import schedule_cache
|
||||||
from tinygrad.engine.realize import method_cache, get_program
|
from tinygrad.engine.realize import method_cache, get_program
|
||||||
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
|
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
|
||||||
from tinygrad.uop.divandmod import fold_divmod_general
|
from tinygrad.uop.divandmod import fold_divmod_general
|
||||||
|
|
|
||||||
2
test/external/fuzz_graph.py
vendored
2
test/external/fuzz_graph.py
vendored
|
|
@ -5,7 +5,7 @@ from tinygrad.helpers import Context, getenv, from_mv
|
||||||
from tinygrad.dtype import dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||||
from tinygrad.engine.realize import BufferXfer, get_runner
|
from tinygrad.engine.realize import BufferXfer, get_runner
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
from tinygrad.uop.ops import UOp, Ops
|
from tinygrad.uop.ops import UOp, Ops
|
||||||
from tinygrad.engine.jit import apply_graph_to_jit
|
from tinygrad.engine.jit import apply_graph_to_jit
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad import dtypes
|
from tinygrad import dtypes
|
||||||
from tinygrad.uop.ops import UOp, Ops
|
from tinygrad.uop.ops import UOp, Ops
|
||||||
from tinygrad.engine.memory import memory_plan_rewrite
|
from tinygrad.schedule.memory import memory_plan_rewrite
|
||||||
|
|
||||||
global_map = {}
|
global_map = {}
|
||||||
held_bufs: set[UOp] = set()
|
held_bufs: set[UOp] = set()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad import Tensor, Variable, Context
|
from tinygrad import Tensor, Variable, Context
|
||||||
from tinygrad.helpers import cpu_events
|
from tinygrad.helpers import cpu_events
|
||||||
from tinygrad.engine.schedule import schedule_cache
|
from tinygrad.schedule import schedule_cache
|
||||||
|
|
||||||
def schedule_one():
|
def schedule_one():
|
||||||
Tensor([1]).schedule()
|
Tensor([1]).schedule()
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
from tinygrad import Tensor, dtypes
|
from tinygrad import Tensor, dtypes
|
||||||
from tinygrad.tensor import _METADATA
|
from tinygrad.tensor import _METADATA
|
||||||
from tinygrad.engine.realize import capturing
|
from tinygrad.engine.realize import capturing
|
||||||
from tinygrad.engine.schedule import linear_to_schedule
|
from tinygrad.schedule import linear_to_schedule
|
||||||
from tinygrad.helpers import Context
|
from tinygrad.helpers import Context
|
||||||
|
|
||||||
@unittest.skip("tensor metadata is no longer supported")
|
@unittest.skip("tensor metadata is no longer supported")
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import unittest, math, time
|
||||||
from tinygrad import Tensor, Device, dtypes, Context
|
from tinygrad import Tensor, Device, dtypes, Context
|
||||||
from tinygrad.uop.ops import UOp, Ops
|
from tinygrad.uop.ops import UOp, Ops
|
||||||
from tinygrad.engine.realize import get_runner
|
from tinygrad.engine.realize import get_runner
|
||||||
from tinygrad.engine.schedule import ExecItem
|
from tinygrad.schedule import ExecItem
|
||||||
from tinygrad.engine.jit import TinyJit
|
from tinygrad.engine.jit import TinyJit
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from tinygrad import Tensor, UOp
|
from tinygrad import Tensor, UOp
|
||||||
from tinygrad.engine.schedule import schedule_cache
|
from tinygrad.schedule import schedule_cache
|
||||||
from tinygrad.apps.llm import Transformer, TransformerConfig
|
from tinygrad.apps.llm import Transformer, TransformerConfig
|
||||||
|
|
||||||
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
|
TEST_CONFIG = TransformerConfig(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
import functools
|
import functools
|
||||||
from tinygrad import Tensor, Variable, UOp
|
from tinygrad import Tensor, Variable, UOp
|
||||||
from tinygrad.uop.ops import KernelInfo
|
from tinygrad.uop.ops import KernelInfo
|
||||||
from tinygrad.engine.schedule import schedule_cache
|
from tinygrad.schedule import schedule_cache
|
||||||
|
|
||||||
def custom_set0_kernel(A:UOp, num:int) -> UOp:
|
def custom_set0_kernel(A:UOp, num:int) -> UOp:
|
||||||
return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}"))
|
return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}"))
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||||
from tinygrad.dtype import DType, dtypes
|
from tinygrad.dtype import DType, dtypes
|
||||||
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
||||||
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||||
from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs
|
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
|
||||||
from tinygrad.engine.schedule import linear_to_schedule
|
from tinygrad.schedule import linear_to_schedule
|
||||||
from tinygrad.nn.state import get_parameters
|
from tinygrad.nn.state import get_parameters
|
||||||
from tinygrad.schedule.rangeify import mop_cleanup
|
from tinygrad.schedule.rangeify import mop_cleanup
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
|
||||||
|
|
@ -1,182 +0,0 @@
|
||||||
import time, inspect
|
|
||||||
from typing import cast
|
|
||||||
from collections import deque
|
|
||||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
|
|
||||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
|
||||||
from tinygrad.device import Buffer, MultiBuffer
|
|
||||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, BEAM, partition
|
|
||||||
from tinygrad.engine.realize import ExecItem
|
|
||||||
|
|
||||||
# **** schedule linearizer
|
|
||||||
|
|
||||||
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
|
|
||||||
def _unwrap_src(s: UOp) -> UOp:
|
|
||||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
|
||||||
return s
|
|
||||||
|
|
||||||
def _split_after(after: UOp) -> tuple[tuple[UOp, ...], tuple[UOp, ...]]:
|
|
||||||
kernels, remaining = partition(after.src[1:], lambda s: s.op in {Ops.CALL, Ops.END})
|
|
||||||
deps, remaining = partition(remaining, lambda s: s.op is Ops.AFTER)
|
|
||||||
if invalid := [s for s in remaining if s.op is not Ops.STORE]:
|
|
||||||
raise AssertionError(f"AFTER source should be CALL, END, STORE, or AFTER, not {invalid[0].op}")
|
|
||||||
return tuple(kernels), tuple(deps)
|
|
||||||
|
|
||||||
def create_schedule(sched_sink:UOp) -> UOp:
|
|
||||||
with cpu_profile(TracingKey("toposort sched_sink")):
|
|
||||||
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
|
||||||
children: dict[UOp, list[UOp]] = {}
|
|
||||||
in_degree: dict[UOp, int] = {}
|
|
||||||
for u in sched_sink.toposort(gate_kernel_sink):
|
|
||||||
if u.op is not Ops.AFTER: continue
|
|
||||||
kernels, after_deps = _split_after(u)
|
|
||||||
for k in kernels:
|
|
||||||
in_degree.setdefault(k, 0)
|
|
||||||
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
|
||||||
kernel_deps = k.src[0].src[1:] if k.op is Ops.END else k.src[1:]
|
|
||||||
for s in kernel_deps + after_deps:
|
|
||||||
match (s := _unwrap_src(s)).op:
|
|
||||||
case Ops.AFTER:
|
|
||||||
for t in _split_after(s)[0]:
|
|
||||||
children.setdefault(t, []).append(k)
|
|
||||||
in_degree[k] += 1
|
|
||||||
case Ops.MSELECT | Ops.MSTACK:
|
|
||||||
for ss in s.src:
|
|
||||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
|
||||||
if ss.op not in {Ops.BUFFER, Ops.PARAM}:
|
|
||||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
|
||||||
for t in _split_after(ss)[0]:
|
|
||||||
children.setdefault(t, []).append(k)
|
|
||||||
in_degree[k] += 1
|
|
||||||
case Ops.BUFFER | Ops.PARAM | Ops.BIND:
|
|
||||||
pass # BUFFER/PARAM is already realized, BIND is a bound variable (not a buffer dependency)
|
|
||||||
case _:
|
|
||||||
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, PARAM, MSELECT, MSTACK, or BIND, not {s.op}")
|
|
||||||
|
|
||||||
with cpu_profile(TracingKey("linearize schedule")):
|
|
||||||
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
|
|
||||||
linearized: list[UOp] = []
|
|
||||||
while len(queue):
|
|
||||||
rk = queue.popleft()
|
|
||||||
if rk.op is Ops.LINEAR:
|
|
||||||
linearized.extend(rk.src)
|
|
||||||
else:
|
|
||||||
k = rk.src[0] if rk.op is Ops.END else rk
|
|
||||||
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
|
|
||||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
|
||||||
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
|
|
||||||
for x in children.get(rk, []):
|
|
||||||
in_degree[x] -= 1
|
|
||||||
if in_degree[x] == 0: queue.append(x)
|
|
||||||
return UOp(Ops.LINEAR, src=tuple(linearized))
|
|
||||||
|
|
||||||
def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
|
||||||
"""Convert a LINEAR UOp to a list of ExecItems."""
|
|
||||||
schedule: list[ExecItem] = []
|
|
||||||
for si in linear.src:
|
|
||||||
ast, buf_uops = si.src[0], si.src[1:]
|
|
||||||
# create subbuffers if needed
|
|
||||||
if ast.op is Ops.BUFFER_VIEW:
|
|
||||||
base = buf_uops[1].buffer
|
|
||||||
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
|
||||||
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
|
||||||
# wrap SINK with BEAM UOp when beam search is enabled
|
|
||||||
if ast.op is Ops.SINK and BEAM >= 1: ast = UOp(Ops.BEAM, src=(ast,), arg=BEAM.value)
|
|
||||||
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
|
|
||||||
metadata = si.arg.metadata
|
|
||||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
|
|
||||||
schedule.append(ExecItem(ast, flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in ubufs]), metadata))
|
|
||||||
elif any(isinstance(x, MultiBuffer) for x in ubufs):
|
|
||||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
|
||||||
dnums = [x for x in ast.variables() if x.expr == '_device_num']
|
|
||||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
|
||||||
schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
|
|
||||||
else:
|
|
||||||
schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
|
|
||||||
return schedule
|
|
||||||
|
|
||||||
from tinygrad.engine.memory import memory_plan_rewrite
|
|
||||||
from tinygrad.engine.realize import capturing
|
|
||||||
from tinygrad.schedule.rangeify import get_kernel_graph
|
|
||||||
from tinygrad.helpers import CAPTURING
|
|
||||||
from tinygrad.uop.ops import PatternMatcher, UPat
|
|
||||||
|
|
||||||
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
|
|
||||||
if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
pm_post_sched_cache = PatternMatcher([
|
|
||||||
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
|
|
||||||
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
|
||||||
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
|
||||||
])
|
|
||||||
|
|
||||||
pm_resolve_linear_call = PatternMatcher([
|
|
||||||
# call LINEAR is resolved here
|
|
||||||
(UPat(Ops.CALL, src=(UPat(Ops.LINEAR),), name="linear_call", allow_any_len=True), lambda linear_call:
|
|
||||||
graph_rewrite(linear_call.src[0], pm_post_sched_cache, ctx=({}, linear_call.src[1:]), walk=True, name="params to buffers")),
|
|
||||||
# LINEAR on LINEAR
|
|
||||||
(UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="x"),
|
|
||||||
lambda x: x.replace(src=tuple(flatten(x.src if x.op is Ops.LINEAR else (x,) for x in x.src)))),
|
|
||||||
])
|
|
||||||
|
|
||||||
schedule_cache: dict[bytes, UOp] = {}
|
|
||||||
# ctx is just for DEBUG on inner
|
|
||||||
def lower_sink_to_linear(function:UOp) -> UOp|None:
|
|
||||||
st = time.perf_counter()
|
|
||||||
if isinstance(function.arg, KernelInfo): return None
|
|
||||||
cache_key = function.key
|
|
||||||
if not SCACHE or (sc_ret:=schedule_cache.get(cache_key, None)) is None:
|
|
||||||
if SPEC: type_verify(function, tensor_spec)
|
|
||||||
# support recursive CALLs
|
|
||||||
linear = create_schedule(get_kernel_graph(function))
|
|
||||||
if SCACHE: schedule_cache[cache_key] = linear
|
|
||||||
else:
|
|
||||||
# schedule cache hit
|
|
||||||
linear = sc_ret
|
|
||||||
if (DEBUG >= 1 and len(linear.src) > 1) or DEBUG >= 3:
|
|
||||||
for frm in inspect.stack():
|
|
||||||
if frm.filename == "<string>": continue
|
|
||||||
if frm.filename.startswith(str(BASEDIR / "apps")): break
|
|
||||||
if not frm.filename.startswith(str(BASEDIR)) and not frm.filename.endswith("/contextlib.py"): break
|
|
||||||
else:
|
|
||||||
frm = None
|
|
||||||
print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
|
||||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {cache_key.hex()[:8]}"+\
|
|
||||||
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
|
||||||
return linear
|
|
||||||
|
|
||||||
pm_schedule = PatternMatcher([
|
|
||||||
(UPat(Ops.SINK, name="function"), lower_sink_to_linear),
|
|
||||||
])
|
|
||||||
|
|
||||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
|
|
||||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], dict[str, int]]:
|
|
||||||
# big_sink srcs are all the Tensors
|
|
||||||
linear_call = graph_rewrite(big_sink, pm_schedule, name="schedule to linear", enter_calls=True)
|
|
||||||
|
|
||||||
# this recursively resolves the linear_call and allocates buffers
|
|
||||||
linear = graph_rewrite(linear_call, pm_resolve_linear_call, name="resolve linear call")
|
|
||||||
|
|
||||||
# vars used in the schedule
|
|
||||||
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
|
|
||||||
# get var_vals
|
|
||||||
var_vals: dict[str, int] = {}
|
|
||||||
for b in big_sink.src[1:]:
|
|
||||||
if b.op is Ops.BIND:
|
|
||||||
nm = b.src[0].expr
|
|
||||||
if nm not in used_vars: continue
|
|
||||||
val = b.src[1].arg
|
|
||||||
if var_vals.get(nm, val) != val: raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}")
|
|
||||||
var_vals[nm] = val
|
|
||||||
|
|
||||||
# jit captures this schedule, no need to execute.
|
|
||||||
if len(capturing) and CAPTURING:
|
|
||||||
capturing[0].add_linear(linear, var_vals)
|
|
||||||
return [], var_vals
|
|
||||||
|
|
||||||
held_bufs = ({b for b in linear_call.src[1:] if b.op is Ops.BUFFER} if linear_call.op is Ops.CALL else set())
|
|
||||||
linear = memory_plan_rewrite(linear, held_bufs)
|
|
||||||
|
|
||||||
# convert LINEAR to ExecItems
|
|
||||||
schedule: list[ExecItem] = linear_to_schedule(linear)
|
|
||||||
return schedule, var_vals
|
|
||||||
|
|
@ -0,0 +1,182 @@
|
||||||
|
import time, inspect
|
||||||
|
from typing import cast
|
||||||
|
from collections import deque
|
||||||
|
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
|
||||||
|
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||||
|
from tinygrad.device import Buffer, MultiBuffer
|
||||||
|
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, BEAM, partition
|
||||||
|
from tinygrad.engine.realize import ExecItem
|
||||||
|
|
||||||
|
# **** schedule linearizer
|
||||||
|
|
||||||
|
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
|
||||||
|
def _unwrap_src(s: UOp) -> UOp:
|
||||||
|
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||||
|
return s
|
||||||
|
|
||||||
|
def _split_after(after: UOp) -> tuple[tuple[UOp, ...], tuple[UOp, ...]]:
|
||||||
|
kernels, remaining = partition(after.src[1:], lambda s: s.op in {Ops.CALL, Ops.END})
|
||||||
|
deps, remaining = partition(remaining, lambda s: s.op is Ops.AFTER)
|
||||||
|
if invalid := [s for s in remaining if s.op is not Ops.STORE]:
|
||||||
|
raise AssertionError(f"AFTER source should be CALL, END, STORE, or AFTER, not {invalid[0].op}")
|
||||||
|
return tuple(kernels), tuple(deps)
|
||||||
|
|
||||||
|
def create_schedule(sched_sink:UOp) -> UOp:
|
||||||
|
with cpu_profile(TracingKey("toposort sched_sink")):
|
||||||
|
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
||||||
|
children: dict[UOp, list[UOp]] = {}
|
||||||
|
in_degree: dict[UOp, int] = {}
|
||||||
|
for u in sched_sink.toposort(gate_kernel_sink):
|
||||||
|
if u.op is not Ops.AFTER: continue
|
||||||
|
kernels, after_deps = _split_after(u)
|
||||||
|
for k in kernels:
|
||||||
|
in_degree.setdefault(k, 0)
|
||||||
|
if k.op is Ops.END: assert k.src[0].op is Ops.CALL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
||||||
|
kernel_deps = k.src[0].src[1:] if k.op is Ops.END else k.src[1:]
|
||||||
|
for s in kernel_deps + after_deps:
|
||||||
|
match (s := _unwrap_src(s)).op:
|
||||||
|
case Ops.AFTER:
|
||||||
|
for t in _split_after(s)[0]:
|
||||||
|
children.setdefault(t, []).append(k)
|
||||||
|
in_degree[k] += 1
|
||||||
|
case Ops.MSELECT | Ops.MSTACK:
|
||||||
|
for ss in s.src:
|
||||||
|
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||||
|
if ss.op not in {Ops.BUFFER, Ops.PARAM}:
|
||||||
|
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||||
|
for t in _split_after(ss)[0]:
|
||||||
|
children.setdefault(t, []).append(k)
|
||||||
|
in_degree[k] += 1
|
||||||
|
case Ops.BUFFER | Ops.PARAM | Ops.BIND:
|
||||||
|
pass # BUFFER/PARAM is already realized, BIND is a bound variable (not a buffer dependency)
|
||||||
|
case _:
|
||||||
|
raise RuntimeError(f"input to kernel must be AFTER, BUFFER, PARAM, MSELECT, MSTACK, or BIND, not {s.op}")
|
||||||
|
|
||||||
|
with cpu_profile(TracingKey("linearize schedule")):
|
||||||
|
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
|
||||||
|
linearized: list[UOp] = []
|
||||||
|
while len(queue):
|
||||||
|
rk = queue.popleft()
|
||||||
|
if rk.op is Ops.LINEAR:
|
||||||
|
linearized.extend(rk.src)
|
||||||
|
else:
|
||||||
|
k = rk.src[0] if rk.op is Ops.END else rk
|
||||||
|
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
|
||||||
|
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||||
|
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
|
||||||
|
for x in children.get(rk, []):
|
||||||
|
in_degree[x] -= 1
|
||||||
|
if in_degree[x] == 0: queue.append(x)
|
||||||
|
return UOp(Ops.LINEAR, src=tuple(linearized))
|
||||||
|
|
||||||
|
def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
||||||
|
"""Convert a LINEAR UOp to a list of ExecItems."""
|
||||||
|
schedule: list[ExecItem] = []
|
||||||
|
for si in linear.src:
|
||||||
|
ast, buf_uops = si.src[0], si.src[1:]
|
||||||
|
# create subbuffers if needed
|
||||||
|
if ast.op is Ops.BUFFER_VIEW:
|
||||||
|
base = buf_uops[1].buffer
|
||||||
|
assert isinstance(base, Buffer), "base can't be MultiBuffer"
|
||||||
|
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||||
|
# wrap SINK with BEAM UOp when beam search is enabled
|
||||||
|
if ast.op is Ops.SINK and BEAM >= 1: ast = UOp(Ops.BEAM, src=(ast,), arg=BEAM.value)
|
||||||
|
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
|
||||||
|
metadata = si.arg.metadata
|
||||||
|
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
|
||||||
|
schedule.append(ExecItem(ast, flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in ubufs]), metadata))
|
||||||
|
elif any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||||
|
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||||
|
dnums = [x for x in ast.variables() if x.expr == '_device_num']
|
||||||
|
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||||
|
schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
|
||||||
|
else:
|
||||||
|
schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
|
||||||
|
return schedule
|
||||||
|
|
||||||
|
from tinygrad.schedule.memory import memory_plan_rewrite
|
||||||
|
from tinygrad.engine.realize import capturing
|
||||||
|
from tinygrad.schedule.rangeify import get_kernel_graph
|
||||||
|
from tinygrad.helpers import CAPTURING
|
||||||
|
from tinygrad.uop.ops import PatternMatcher, UPat
|
||||||
|
|
||||||
|
def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
|
||||||
|
if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
pm_post_sched_cache = PatternMatcher([
|
||||||
|
(UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
|
||||||
|
# create new BUFFERs for LUNIQUE BUFFERs from rangeify
|
||||||
|
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
|
||||||
|
])
|
||||||
|
|
||||||
|
pm_resolve_linear_call = PatternMatcher([
|
||||||
|
# call LINEAR is resolved here
|
||||||
|
(UPat(Ops.CALL, src=(UPat(Ops.LINEAR),), name="linear_call", allow_any_len=True), lambda linear_call:
|
||||||
|
graph_rewrite(linear_call.src[0], pm_post_sched_cache, ctx=({}, linear_call.src[1:]), walk=True, name="params to buffers")),
|
||||||
|
# LINEAR on LINEAR
|
||||||
|
(UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="x"),
|
||||||
|
lambda x: x.replace(src=tuple(flatten(x.src if x.op is Ops.LINEAR else (x,) for x in x.src)))),
|
||||||
|
])
|
||||||
|
|
||||||
|
schedule_cache: dict[bytes, UOp] = {}
|
||||||
|
# ctx is just for DEBUG on inner
|
||||||
|
def lower_sink_to_linear(function:UOp) -> UOp|None:
|
||||||
|
st = time.perf_counter()
|
||||||
|
if isinstance(function.arg, KernelInfo): return None
|
||||||
|
cache_key = function.key
|
||||||
|
if not SCACHE or (sc_ret:=schedule_cache.get(cache_key, None)) is None:
|
||||||
|
if SPEC: type_verify(function, tensor_spec)
|
||||||
|
# support recursive CALLs
|
||||||
|
linear = create_schedule(get_kernel_graph(function))
|
||||||
|
if SCACHE: schedule_cache[cache_key] = linear
|
||||||
|
else:
|
||||||
|
# schedule cache hit
|
||||||
|
linear = sc_ret
|
||||||
|
if (DEBUG >= 1 and len(linear.src) > 1) or DEBUG >= 3:
|
||||||
|
for frm in inspect.stack():
|
||||||
|
if frm.filename == "<string>": continue
|
||||||
|
if frm.filename.startswith(str(BASEDIR / "apps")): break
|
||||||
|
if not frm.filename.startswith(str(BASEDIR)) and not frm.filename.endswith("/contextlib.py"): break
|
||||||
|
else:
|
||||||
|
frm = None
|
||||||
|
print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||||
|
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {cache_key.hex()[:8]}"+\
|
||||||
|
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
||||||
|
return linear
|
||||||
|
|
||||||
|
pm_schedule = PatternMatcher([
|
||||||
|
(UPat(Ops.SINK, name="function"), lower_sink_to_linear),
|
||||||
|
])
|
||||||
|
|
||||||
|
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
|
||||||
|
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], dict[str, int]]:
|
||||||
|
# big_sink srcs are all the Tensors
|
||||||
|
linear_call = graph_rewrite(big_sink, pm_schedule, name="schedule to linear", enter_calls=True)
|
||||||
|
|
||||||
|
# this recursively resolves the linear_call and allocates buffers
|
||||||
|
linear = graph_rewrite(linear_call, pm_resolve_linear_call, name="resolve linear call")
|
||||||
|
|
||||||
|
# vars used in the schedule
|
||||||
|
used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
|
||||||
|
# get var_vals
|
||||||
|
var_vals: dict[str, int] = {}
|
||||||
|
for b in big_sink.src[1:]:
|
||||||
|
if b.op is Ops.BIND:
|
||||||
|
nm = b.src[0].expr
|
||||||
|
if nm not in used_vars: continue
|
||||||
|
val = b.src[1].arg
|
||||||
|
if var_vals.get(nm, val) != val: raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}")
|
||||||
|
var_vals[nm] = val
|
||||||
|
|
||||||
|
# jit captures this schedule, no need to execute.
|
||||||
|
if len(capturing) and CAPTURING:
|
||||||
|
capturing[0].add_linear(linear, var_vals)
|
||||||
|
return [], var_vals
|
||||||
|
|
||||||
|
held_bufs = ({b for b in linear_call.src[1:] if b.op is Ops.BUFFER} if linear_call.op is Ops.CALL else set())
|
||||||
|
linear = memory_plan_rewrite(linear, held_bufs)
|
||||||
|
|
||||||
|
# convert LINEAR to ExecItems
|
||||||
|
schedule: list[ExecItem] = linear_to_schedule(linear)
|
||||||
|
return schedule, var_vals
|
||||||
|
|
@ -13,10 +13,10 @@ from tinygrad.gradient import compute_gradient
|
||||||
from tinygrad.mixin import OpMixin, ReductionStr
|
from tinygrad.mixin import OpMixin, ReductionStr
|
||||||
from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
|
from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
|
||||||
from tinygrad.uop.ops import _broadcast_shape
|
from tinygrad.uop.ops import _broadcast_shape
|
||||||
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
|
from tinygrad.schedule import ExecItem, complete_create_schedule_with_vars
|
||||||
from tinygrad.device import Buffer, canonicalize_device
|
from tinygrad.device import Buffer, canonicalize_device
|
||||||
from tinygrad.engine.realize import run_schedule
|
from tinygrad.engine.realize import run_schedule
|
||||||
from tinygrad.engine.callify import transform_to_call
|
from tinygrad.callify import transform_to_call
|
||||||
|
|
||||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue