mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
e27444a0ff
commit
bb652352c7
5 changed files with 32 additions and 80 deletions
|
|
@ -17,13 +17,11 @@ The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not al
|
|||
|
||||
## Scheduling
|
||||
|
||||
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/schedule/__init__.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
|
||||
|
||||
::: tinygrad.engine.realize.ExecItem
|
||||
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/schedule/__init__.py) converts the graph of UOps into a `LINEAR` UOp whose `src` is a list of `CALL` UOps. One `CALL` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. The `CALL`'s `src[0]` (a `SINK` ast) specifies what compute to run, and the remaining `src` entries are the buffers to run it on.
|
||||
|
||||
## Lowering
|
||||
|
||||
The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ExecItem` by populating its `prg` field with
|
||||
The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers each `CALL` by compiling its ast into a `PROGRAM` and running it.
|
||||
|
||||
::: tinygrad.engine.realize.run_linear
|
||||
|
||||
|
|
@ -35,13 +33,7 @@ Then we render the UOps into code with a `Renderer`, then we compile the code to
|
|||
|
||||
## Execution
|
||||
|
||||
Creating `ExecItem`, which has a run method
|
||||
|
||||
::: tinygrad.engine.realize.ExecItem
|
||||
options:
|
||||
members: true
|
||||
|
||||
Lists of `ExecItem` can be condensed into a single ExecItem with the Graph API (rename to Queue?)
|
||||
`run_linear` walks the `LINEAR` UOp, dispatching each `CALL` to a runner (kernel, copy, view, encdec, or graph).
|
||||
|
||||
## Runtime
|
||||
|
||||
|
|
|
|||
|
|
@ -1,31 +1,39 @@
|
|||
# kernel8_batched_gmem.s from https://seb-v.github.io/optimization/update/2025/01/20/Fast-GPU-Matrix-multiplication.html
|
||||
# sudo PATH=/opt/homebrew/Cellar/llvm/20.1.6/bin:$PATH AMD_LLVM=0 AMD=1 DEBUG=2 python3 extra/gemm/amd_matmul.py
|
||||
import pathlib
|
||||
from dataclasses import replace
|
||||
from tinygrad import Tensor, Device, Context, GlobalCounters
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.engine.realize import run_linear
|
||||
|
||||
N = 4096
|
||||
run_count = 5
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast = (Tensor.empty(N, N)@Tensor.empty(N, N)).schedule_linear().src[-1].src[0]
|
||||
prg = get_program(ast, Device.default.renderer)
|
||||
def make_matmul_kernel(name:str, src:str, local_size:int):
|
||||
def fxn(a:UOp, b:UOp, c:UOp) -> UOp:
|
||||
threads = UOp.special(local_size, "lidx0")
|
||||
wg_x = UOp.special(N//128, "gidx0")
|
||||
wg_y = UOp.special(N//128, "gidx1")
|
||||
sink = UOp.sink(a.base, b.base, c.base, threads, wg_x, wg_y, arg=KernelInfo(name, estimates=Estimates(ops=2*N**3, mem=3*N*N*4)))
|
||||
lib = Device[Device.DEFAULT].compiler.compile_cached(src)
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
|
||||
return fxn
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("ASM") == 1:
|
||||
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel8_batched_gmem.s").read_text()
|
||||
prgfast = replace(prg, name="kernel", src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
|
||||
name, local_size = "kernel", 128
|
||||
elif getenv("ASM") == -1:
|
||||
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel3_registers.cpp").read_text()
|
||||
prgfast = replace(prg, name="kernel3_registers", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
|
||||
name, local_size = "kernel3_registers", 256
|
||||
elif getenv("ASM") == -2:
|
||||
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel4_gmem_df.cpp").read_text()
|
||||
prgfast = replace(prg, name="kernel4_gmem_db", src=src, global_size=[N//128, N//128, 1], local_size=[256, 1, 1])
|
||||
name, local_size = "kernel4_gmem_db", 256
|
||||
else:
|
||||
src = (pathlib.Path(__file__).parent / "amd_seb" / "kernel5_lds_optim.cpp").read_text()
|
||||
prgfast = replace(prg, name="kernel5_lds_optim", src=src, global_size=[N//128, N//128, 1], local_size=[128, 1, 1])
|
||||
runner = CompiledRunner(prgfast)
|
||||
name, local_size = "kernel5_lds_optim", 128
|
||||
|
||||
a = Tensor.randn(N, N).realize()
|
||||
b = Tensor.randn(N, N).realize()
|
||||
|
|
@ -35,8 +43,8 @@ if __name__ == "__main__":
|
|||
with Context(DEBUG=2):
|
||||
for _ in range(run_count): tc = (a@b).realize()
|
||||
|
||||
linear = Tensor.custom_kernel(a, b, c, fxn=make_matmul_kernel(name, src, local_size))[2].schedule_linear()
|
||||
GlobalCounters.reset()
|
||||
ei = ExecItem(ast, [a.uop.buffer, b.uop.buffer, c.uop.buffer], prg=runner)
|
||||
with Context(DEBUG=2):
|
||||
for _ in range(run_count): ei.run(wait=True)
|
||||
for _ in range(run_count): run_linear(linear)
|
||||
print(f"custom {(c-tc).square().mean().item()}")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import triton.language as tl
|
|||
from triton.compiler import AttrsDescriptor, ASTSource, compile as triton_compile
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, ProgramSpec
|
||||
from tinygrad.engine.realize import CompiledRunner, ProgramSpec
|
||||
from tinygrad.helpers import getenv
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
|
|
@ -91,10 +91,12 @@ if __name__ == "__main__":
|
|||
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
|
||||
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
|
||||
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
|
||||
ei = ExecItem(ast, [x.ensure_allocated() for x in bufs], last_call.arg.metadata, prg=CompiledRunner(prg))
|
||||
runner = CompiledRunner(prg)
|
||||
all_bufs = [x.ensure_allocated() for x in bufs]
|
||||
prg_bufs = [all_bufs[i] for i in runner.p.globals]
|
||||
tflops = []
|
||||
for i in range(5):
|
||||
tm = ei.run(wait=True)
|
||||
tm = runner(prg_bufs, {}, wait=True)
|
||||
tflops.append((2*M*K*N/tm)*1e-12)
|
||||
print(f"TFLOPS: {max(tflops):.2f}")
|
||||
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ class TestSchedule(unittest.TestCase):
|
|||
linear, _ = Tensor.linear_with_vars(v)
|
||||
self.assertEqual(len(linear.src), 0)
|
||||
|
||||
# NOTE: because empty does not have a lowered ExecItem if realize is called on a childless empty, it never gets allocated.
|
||||
# NOTE: because empty does not have a lowered kernel, if realize is called on a childless empty, it never gets allocated.
|
||||
def test_childless_empty_never_allocates(self):
|
||||
a = Tensor.empty(10)
|
||||
a.realize()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from typing import cast, Callable, Iterator
|
||||
import time, pprint, random, itertools, math, contextlib, weakref
|
||||
import time, random, itertools, math, contextlib, weakref
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import BEAM, DEVECTORIZE, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
|
||||
from tinygrad.helpers import prod, unwrap, EMULATED_DTYPES, flatten
|
||||
from tinygrad.helpers import prod, EMULATED_DTYPES, flatten
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite
|
||||
from tinygrad.device import Device, Buffer, MultiBuffer
|
||||
from tinygrad.renderer import ProgramSpec, Estimates
|
||||
|
|
@ -151,56 +151,6 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
|||
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
|
||||
return ret
|
||||
|
||||
# **************** lowering functions ****************
|
||||
|
||||
# NOTE: ctx is the buffers
|
||||
si_lowerer = PatternMatcher([
|
||||
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
|
||||
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
|
||||
(UPat(Ops.COPY), lambda ctx: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"), lambda ctx,cf: Device[cf.device if isinstance(cf.device,str) else cf.device[0]].graph(cf, ctx))
|
||||
])
|
||||
|
||||
@dataclass
|
||||
class ExecItem:
|
||||
ast: UOp
|
||||
bufs: list[Buffer|None] = field(default_factory=list)
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
fixedvars: dict[str, int] = field(default_factory=dict)
|
||||
prg: Runner|None = None
|
||||
|
||||
def lower(self):
|
||||
"""Populate self.prg by lowering the AST."""
|
||||
if self.prg is not None: return self
|
||||
try: self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"error lowering {self.ast.op}")
|
||||
print("tensor operations:")
|
||||
pprint.pprint(self.metadata, indent=2)
|
||||
raise e
|
||||
return self
|
||||
|
||||
def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
|
||||
if self.prg is None: self.lower()
|
||||
assert self.prg is not None
|
||||
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
|
||||
# reorder bufs to match program globals if needed
|
||||
_bufs = [self.bufs[i] for i in self.prg.p.globals] if isinstance(self.prg, CompiledRunner) else self.bufs
|
||||
bufs = [unwrap(x) for x in _bufs] if jit else [unwrap(x).ensure_allocated() for x in _bufs]
|
||||
if PROFILE:
|
||||
payload = {"metadata":self.metadata, "var_vals":var_vals, "bufs":[b.trace_num for b in bufs], "name":self.prg.display_name}
|
||||
payload["outputs"], payload["inputs"] = (self.prg.p.outs, self.prg.p.ins) if isinstance(self.prg, CompiledRunner) else ([0], [1])
|
||||
cpu_events.append(ProfilePointEvent(self.prg.device, "exec", len(cpu_events), payload))
|
||||
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
|
||||
if do_update_stats:
|
||||
update_stats(self.prg.display_name, self.prg.device, self.prg.estimates, var_vals, et, len(bufs), jit, self.metadata, self.prg.first_run)
|
||||
self.prg.first_run = False
|
||||
return et
|
||||
|
||||
# **************** run linear ****************
|
||||
|
||||
capturing: list = [] # put classes with an add_linear method in here
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue