mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
execution_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43c033e5e5 | ||
|
|
76f2f14233 | ||
|
|
5515de6553 | ||
|
|
439c4319ec |
13 changed files with 257 additions and 51 deletions
|
|
@ -45,9 +45,10 @@ print(f"The schedule contains {len(schedule)} items.")
|
||||||
for si in schedule: print(str(si)[:80])
|
for si in schedule: print(str(si)[:80])
|
||||||
|
|
||||||
# *****
|
# *****
|
||||||
# 4. Lower and run the schedule.
|
# 4. Run the schedule.
|
||||||
|
|
||||||
for si in tqdm(schedule): si.run()
|
from tinygrad.engine.realize import run_schedule
|
||||||
|
run_schedule(schedule)
|
||||||
|
|
||||||
# *****
|
# *****
|
||||||
# 5. Print the weight change
|
# 5. Print the weight change
|
||||||
|
|
|
||||||
4
test/external/external_benchmark_resnet.py
vendored
4
test/external/external_benchmark_resnet.py
vendored
|
|
@ -71,8 +71,8 @@ class BenchmarkResnetTrain(unittest.TestCase):
|
||||||
|
|
||||||
y = x.sequential(layer).contiguous().contiguous_backward()
|
y = x.sequential(layer).contiguous().contiguous_backward()
|
||||||
y.sum().backward()
|
y.sum().backward()
|
||||||
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
|
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
|
||||||
else: sched, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
|
else: sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
|
||||||
|
|
||||||
for _ in range(JITCNT):
|
for _ in range(JITCNT):
|
||||||
run_schedule(list(sched))
|
run_schedule(list(sched))
|
||||||
|
|
|
||||||
|
|
@ -49,8 +49,8 @@ class BenchmarkBertTrain(unittest.TestCase):
|
||||||
|
|
||||||
y = layer(*inputs).contiguous().contiguous_backward()
|
y = layer(*inputs).contiguous().contiguous_backward()
|
||||||
y.sum().backward()
|
y.sum().backward()
|
||||||
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
|
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
|
||||||
else: sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
|
else: sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
|
||||||
|
|
||||||
for _ in range(JITCNT):
|
for _ in range(JITCNT):
|
||||||
run_schedule(sched)
|
run_schedule(sched)
|
||||||
|
|
|
||||||
|
|
@ -865,7 +865,7 @@ class TestIdxUpcast(unittest.TestCase):
|
||||||
for src in ast.src:
|
for src in ast.src:
|
||||||
if (ret:=self._find_op(src, op)) is not None: return ret
|
if (ret:=self._find_op(src, op)) is not None: return ret
|
||||||
def _schedule_render(self, a: Tensor):
|
def _schedule_render(self, a: Tensor):
|
||||||
schedule, _ = a.schedule_with_vars()
|
schedule, _buffer_map, _var_vals = a.schedule_with_vars()
|
||||||
for s in schedule:
|
for s in schedule:
|
||||||
if s.ast.op is Ops.SINK:
|
if s.ast.op is Ops.SINK:
|
||||||
renderer = Device[s.bufs[0].device].renderer
|
renderer = Device[s.bufs[0].device].renderer
|
||||||
|
|
|
||||||
|
|
@ -481,7 +481,7 @@ class TestUOpMethod(unittest.TestCase):
|
||||||
a = UOp.variable("a", 1, 10)
|
a = UOp.variable("a", 1, 10)
|
||||||
uop_var = Tensor(a.bind(1))
|
uop_var = Tensor(a.bind(1))
|
||||||
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
|
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
|
||||||
_, var_vals = (uop_var+st_var).schedule_with_vars()
|
_, _, var_vals = (uop_var+st_var).schedule_with_vars()
|
||||||
self.assertEqual(len(var_vals), 1)
|
self.assertEqual(len(var_vals), 1)
|
||||||
self.assertEqual(list(var_vals)[0], a.expr)
|
self.assertEqual(list(var_vals)[0], a.expr)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ class TestRingAllReduce(unittest.TestCase):
|
||||||
N = 4
|
N = 4
|
||||||
ds = tuple(f"CPU:{i}" for i in range(N))
|
ds = tuple(f"CPU:{i}" for i in range(N))
|
||||||
t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
|
t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
|
||||||
schedules = t.sum(0).schedule_with_vars()[0]
|
schedules, _, _ = t.sum(0).schedule_with_vars()
|
||||||
copies = [si for si in schedules if si.ast.op is Ops.COPY]
|
copies = [si for si in schedules if si.ast.op is Ops.COPY]
|
||||||
pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
|
pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
|
||||||
# N*(N-1) scatter reduce, and N*(N-1) allgather
|
# N*(N-1) scatter reduce, and N*(N-1) allgather
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class TestScheduleCache(unittest.TestCase):
|
||||||
x = Tensor.ones(10).contiguous().realize()
|
x = Tensor.ones(10).contiguous().realize()
|
||||||
|
|
||||||
t = x + Tensor(v.bind(42))
|
t = x + Tensor(v.bind(42))
|
||||||
_, var_vals = t.schedule_with_vars()
|
_, _, var_vals = t.schedule_with_vars()
|
||||||
self.assertEqual(var_vals, {'pos': 42})
|
self.assertEqual(var_vals, {'pos': 42})
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
|
|
|
||||||
170
tinygrad/engine/execution.py
Normal file
170
tinygrad/engine/execution.py
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any, cast
|
||||||
|
from tinygrad.helpers import DEBUG, GlobalCounters, all_same, colored, ansilen, PROFILE, ProfilePointEvent, cpu_events, time_to_str, TRACEMETA
|
||||||
|
from tinygrad.uop.ops import UOp, Ops, sym_infer
|
||||||
|
from tinygrad.device import Device, Buffer
|
||||||
|
|
||||||
|
# **************** ExecutionUnit ****************
|
||||||
|
|
||||||
|
class ExecutionUnit:
  """
  A bound, ready-to-execute unit. Replaces CapturedJit.

  Takes ExecItems and binds them to real device resources on execution.

  Binding (picking a runner and resolving the buffer list for every item) is lazy: it happens on the
  first call, and again after `update(buffers=...)` or `free_intermediates()` drops the bound state.
  """

  def __init__(self, items: list):
    """
    Create an ExecutionUnit from ExecItems.

    Args:
      items: ExecItems (with bufs as Buffers or UOps, lib optionally set)
    """
    # function-scope import, as in the other methods of this class; only needed for the annotation below
    from tinygrad.engine.realize import ExecItem

    self.items: list[ExecItem] = items
    # buffer UOp -> Buffer overrides, consulted by _bind() before falling back to uop.buffer
    self.buffer_map: dict[UOp, Buffer] = {}
    # variable values merged with each item's fixedvars at execution time
    self.var_vals: dict[str, int] = {}

    # (runner, bufs, metadata, fixedvars) per item - lazy, created on first call
    self._bound_items: list[tuple[Any, list[Buffer], tuple, dict[str, int]]]|None = None
    # placeholders for graph creation, not used yet (see the TODO in __call__)
    self._graphs: list|None = None
    self._first_run = True

  def _bind(self):
    """
    Create runners from lib and bind buffers.

    Fills `self._bound_items` with one `(runner, bufs, metadata, fixedvars)` tuple per item.

    Raises:
      RuntimeError: if an item has no `prg` and its ast op is not SINK, BUFFER_VIEW, COPY or ENCDEC.
    """
    from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer, ViewOp, EncDec, get_runner, get_program

    self._bound_items = []
    for item in self.items:
      # Get buffers - prefer buf_uops with buffer_map, fall back to bufs for backwards compatibility
      bufs: list[Buffer]
      if item.buf_uops:
        bufs = []
        for uop in item.buf_uops:
          # explicit None check: the fallback must depend on whether the uop is mapped, not on the
          # truth value of the mapped object. uop.buffer is still only touched when there is no mapping.
          mapped = self.buffer_map.get(uop)
          bufs.append(cast(Buffer, mapped if mapped is not None else uop.buffer))
      else:
        # NOTE(review): dropping None entries shifts positions relative to item.bufs, and __call__ indexes
        # this list with runner.p.globals - confirm None never appears in the bufs of a compiled kernel
        bufs = [buf for buf in item.bufs if buf is not None]

      # Create runner from lib or use existing prg
      op = item.ast.op
      if item.prg is not None:
        runner = item.prg
      elif op is Ops.SINK:
        device = bufs[0].device
        if item.lib is not None:
          # Create runner from cached lib
          runner = CompiledRunner(get_program(item.ast, Device[device].renderer), item.lib)
        else:
          # Compile and create runner
          runner = get_runner(device, item.ast)
      elif op is Ops.BUFFER_VIEW:
        runner = ViewOp(bufs[0])
      elif op is Ops.COPY:
        # BufferXfer only when the allocator has _transfer and all buffers share the same device prefix
        same_backend = hasattr(Device[bufs[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in bufs])
        copy_runner = BufferXfer if same_backend else BufferCopy
        runner = copy_runner(bufs[0].nbytes, bufs[0].device, bufs[1].device)
      elif op is Ops.ENCDEC:
        runner = EncDec(item.ast, bufs[0].nbytes, bufs[1].device)
      else:
        raise RuntimeError(f"unknown op {item.ast.op}")

      self._bound_items.append((runner, bufs, item.metadata, item.fixedvars))

  def update(self, buffers: dict[UOp, Buffer]|None = None, var_vals: dict[str, int]|None = None) -> ExecutionUnit:
    """
    Update buffer mapping and/or var_vals for next run. Returns self for chaining.

    Args:
      buffers: entries merged into `self.buffer_map`. Passing any dict (even an empty one) drops the bound
        items, so the next call rebinds.
      var_vals: entries merged into `self.var_vals`. Does not trigger a rebind.
    """
    if buffers is not None:
      self.buffer_map.update(buffers)
      # Need to rebind if we update buffers
      self._bound_items = None
    if var_vals is not None:
      self.var_vals.update(var_vals)
    return self

  def __call__(self, var_vals: dict[str, int]|None = None, wait=False, do_update_stats=True, jit=False) -> float|None:
    """
    Execute all items.

    Args:
      var_vals: merged into `self.var_vals` before running (the merge persists for later calls).
      wait: passed to every runner; timing is also requested whenever DEBUG >= 2.
      do_update_stats: update GlobalCounters (and print the per-kernel line at DEBUG >= 2).
      jit: skip `ensure_allocated()` on the buffers and use the jit header color in debug output.

    Returns:
      The sum of the times reported by the runners if `wait` is set, otherwise None.
    """
    from tinygrad.engine.realize import CompiledRunner

    if var_vals is not None:
      self.var_vals.update(var_vals)

    # Lazy bind on first call
    if self._bound_items is None:
      self._bind()
    assert self._bound_items is not None

    # TODO: create graphs on first run (self._first_run / self._graphs are reserved for this)

    # Execute all items
    total_et = 0.0
    for runner, bufs, metadata, fixedvars in self._bound_items:
      # fixedvars is the right-hand operand, so it wins over self.var_vals on a key clash
      merged_var_vals = self.var_vals | fixedvars

      # Ensure buffers are allocated (skip if jit - already allocated)
      if not jit:
        for b in bufs:
          b.ensure_allocated()

      # Reorder bufs to match program globals if needed
      ordered_bufs = [bufs[i] for i in runner.p.globals] if isinstance(runner, CompiledRunner) else bufs

      # PROFILE events
      if PROFILE:
        payload = {"metadata":metadata, "var_vals":merged_var_vals, "bufs":[b.trace_num for b in ordered_bufs], "name":runner.display_name}
        # non-compiled runners are recorded with buffer 0 as the output and buffer 1 as the input
        payload["outputs"], payload["inputs"] = (runner.p.outs, runner.p.ins) if isinstance(runner, CompiledRunner) else ([0], [1])
        cpu_events.append(ProfilePointEvent(runner.device, "exec", len(cpu_events), payload))

      et = runner(ordered_bufs, merged_var_vals, wait=wait or DEBUG >= 2)
      if et is not None:
        total_et += et

      # Update stats
      if do_update_stats:
        GlobalCounters.kernel_count += 1
        op_est = sym_infer(runner.estimates.ops, merged_var_vals)
        mem_est = sym_infer(runner.estimates.mem, merged_var_vals)
        GlobalCounters.global_ops += op_est
        GlobalCounters.global_mem += mem_est
        if et is not None:
          GlobalCounters.time_sum_s += et
        if DEBUG >= 2:
          lds_est = sym_infer(runner.estimates.lds, merged_var_vals)
          mem_est = min(mem_est, lds_est)  # there can't be more memory accessed than loads/stores
          header_color = 'magenta' if jit else ('green' if runner.first_run else None)
          ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
          # 1e-20 stands in for a missing or zero time so the rates below never divide by zero
          flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
          flops_str = f"{flops*1e-9:7.0f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:7.0f} TFLOPS", 'green')
          mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \
            colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green')
          print(f"{colored(f'*** {runner.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
            f" {runner.display_name+' '*(46-ansilen(runner.display_name))} arg {len(ordered_bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+
            ("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
            f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in metadata] if metadata else ''}")
        runner.first_run = False

    # timing forced only by DEBUG >= 2 is not reported to the caller
    return total_et if wait else None

  def __add__(self, other: ExecutionUnit) -> ExecutionUnit:
    """
    Combine two ExecutionUnits, rebuild graph lazily.

    The result runs the items of `self` followed by the items of `other`; on key clashes in `buffer_map`
    or `var_vals` the entry from `other` wins. The result starts unbound.
    """
    # let Python try the reflected operation / raise TypeError instead of failing on a missing attribute
    if not isinstance(other, ExecutionUnit): return NotImplemented
    combined = ExecutionUnit(self.items + other.items)
    combined.buffer_map = {**self.buffer_map, **other.buffer_map}
    combined.var_vals = {**self.var_vals, **other.var_vals}
    return combined

  def free_intermediates(self):
    """
    Deallocate internal buffers.

    Every currently allocated buffer in `self.buffer_map` is deallocated, then the bound state is reset so
    the next call rebinds.
    """
    for buf in self.buffer_map.values():
      if buf.is_allocated():
        buf.deallocate()
    # Reset bound state
    self._bound_items = None
    self._graphs = None
    self._first_run = True
|
||||||
|
|
@ -195,6 +195,8 @@ class CapturedJit(Generic[ReturnType]):
|
||||||
|
|
||||||
# jit exec
|
# jit exec
|
||||||
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
|
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
|
||||||
|
from tinygrad.engine.execution import ExecutionUnit
|
||||||
|
|
||||||
# assign inputs
|
# assign inputs
|
||||||
for idx, offset, device, size, dtype in self.extra_view_inputs:
|
for idx, offset, device, size, dtype in self.extra_view_inputs:
|
||||||
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
|
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
|
||||||
|
|
@ -213,7 +215,8 @@ class CapturedJit(Generic[ReturnType]):
|
||||||
self._first_run = False
|
self._first_run = False
|
||||||
|
|
||||||
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
||||||
for ei in self._jit_cache: ei.run(var_vals, jit=True)
|
# Use ExecutionUnit for execution
|
||||||
|
ExecutionUnit(self._jit_cache).update(var_vals=var_vals)(jit=True)
|
||||||
self._clear_inputs()
|
self._clear_inputs()
|
||||||
return self.ret
|
return self.ret
|
||||||
|
|
||||||
|
|
@ -251,7 +254,7 @@ class TinyJit(Generic[ReturnType]):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def add(self, ei:ExecItem):
|
def add(self, ei:ExecItem):
|
||||||
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg))
|
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.lib, ei.prg))
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
assert self.fxn is not None, "can't reset without function"
|
assert self.fxn is not None, "can't reset without function"
|
||||||
|
|
@ -308,14 +311,16 @@ class TinyJit(Generic[ReturnType]):
|
||||||
|
|
||||||
# prune independent kernels (optional)
|
# prune independent kernels (optional)
|
||||||
if self.prune:
|
if self.prune:
|
||||||
|
from tinygrad.engine.execution import ExecutionUnit
|
||||||
depends = set(input_buffers)
|
depends = set(input_buffers)
|
||||||
update_depends(depends, jit_cache)
|
update_depends(depends, jit_cache)
|
||||||
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
|
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
|
||||||
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
||||||
# run the onetime kernels here
|
# run the onetime kernels here
|
||||||
for ei in onetime:
|
if onetime:
|
||||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
for ei in onetime:
|
||||||
ei.run(var_vals, jit=True)
|
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||||
|
ExecutionUnit(onetime).update(var_vals=var_vals)(jit=True)
|
||||||
jit_cache = pruned
|
jit_cache = pruned
|
||||||
|
|
||||||
# memory planning (optional)
|
# memory planning (optional)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
||||||
from tinygrad.engine.realize import ExecItem
|
from tinygrad.engine.realize import ExecItem
|
||||||
from tinygrad.device import Device, Buffer
|
from tinygrad.device import Device, Buffer
|
||||||
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
|
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
|
||||||
from tinygrad.uop.ops import Ops
|
from tinygrad.uop.ops import Ops, UOp
|
||||||
from tinygrad.dtype import dtypes, ImageDType
|
from tinygrad.dtype import dtypes, ImageDType
|
||||||
from tinygrad.runtime.support.memory import TLSFAllocator
|
from tinygrad.runtime.support.memory import TLSFAllocator
|
||||||
|
|
||||||
|
|
@ -63,8 +63,16 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
|
||||||
|
|
||||||
return assigned
|
return assigned
|
||||||
|
|
||||||
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
|
def memory_planner(schedule:list[ExecItem]) -> tuple[list[ExecItem], dict[UOp, Buffer]]:
|
||||||
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
|
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
|
||||||
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
|
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
|
||||||
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
|
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
|
||||||
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
|
new_schedule = [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars, buf_uops=si.buf_uops)
|
||||||
|
for si in schedule]
|
||||||
|
# Build buffer_map from buf_uops -> assigned buffers
|
||||||
|
buffer_map: dict[UOp, Buffer] = {}
|
||||||
|
for si in new_schedule:
|
||||||
|
for i, uop in enumerate(si.buf_uops):
|
||||||
|
if uop not in buffer_map and i < len(si.bufs) and si.bufs[i] is not None:
|
||||||
|
buffer_map[uop] = si.bufs[i] # type: ignore
|
||||||
|
return new_schedule, buffer_map
|
||||||
|
|
|
||||||
|
|
@ -183,15 +183,21 @@ si_lowerer = PatternMatcher([
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExecItem:
|
class ExecItem:
|
||||||
ast: UOp
|
ast: UOp
|
||||||
bufs: list[Buffer|None] = field(default_factory=list)
|
bufs: list[Buffer|None] = field(default_factory=list) # TODO: deprecate, use buf_uops + buffer_map
|
||||||
metadata: tuple[Metadata, ...] = ()
|
metadata: tuple[Metadata, ...] = ()
|
||||||
fixedvars: dict[str, int] = field(default_factory=dict)
|
fixedvars: dict[str, int] = field(default_factory=dict)
|
||||||
|
lib: bytes|None = None # compiled binary, None for COPY/VIEW/ENCDEC
|
||||||
prg: Runner|None = None
|
prg: Runner|None = None
|
||||||
|
buf_uops: tuple[UOp, ...] = () # buffer UOps, binding happens in ExecutionUnit
|
||||||
|
|
||||||
def lower(self):
|
def lower(self):
|
||||||
"""Populate self.prg by lowering the AST."""
|
"""Populate self.prg and self.lib by lowering the AST."""
|
||||||
if self.prg is not None: return self
|
if self.prg is not None: return self
|
||||||
try: self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
|
try:
|
||||||
|
self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
|
||||||
|
# Store lib for SINK ops (compiled kernels)
|
||||||
|
if isinstance(self.prg, CompiledRunner):
|
||||||
|
self.lib = self.prg.lib
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if DEBUG >= 2:
|
if DEBUG >= 2:
|
||||||
print(f"error lowering {self.ast.op}")
|
print(f"error lowering {self.ast.op}")
|
||||||
|
|
@ -237,24 +243,39 @@ class ExecItem:
|
||||||
|
|
||||||
capturing: list = [] # put classes with an add method in here
|
capturing: list = [] # put classes with an add method in here
|
||||||
|
|
||||||
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
|
def run_schedule(schedule:list[ExecItem], buffer_map:dict[UOp, Buffer]|None=None, var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||||
while len(schedule):
|
from tinygrad.engine.execution import ExecutionUnit
|
||||||
ei = schedule.pop(0).lower()
|
if buffer_map is None: buffer_map = {}
|
||||||
|
|
||||||
|
# Lower all items first
|
||||||
|
lowered: list[ExecItem] = []
|
||||||
|
for ei in schedule:
|
||||||
|
ei = ei.lower()
|
||||||
if len(capturing) and CAPTURING: capturing[0].add(ei)
|
if len(capturing) and CAPTURING: capturing[0].add(ei)
|
||||||
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
|
lowered.append(ei)
|
||||||
# copy in allocated buffers from the GPU
|
|
||||||
bufs = [b for b in ei.bufs if b is not None]
|
|
||||||
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
|
|
||||||
for cpu_b, gpu_b in zip(nb, bufs):
|
|
||||||
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
|
|
||||||
|
|
||||||
# run on GPU
|
if VALIDATE_WITH_CPU:
|
||||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
# Run item by item with CPU validation
|
||||||
|
for ei in lowered:
|
||||||
|
if ei.ast.op is Ops.SINK:
|
||||||
|
# copy in allocated buffers from the GPU
|
||||||
|
bufs = [b for b in ei.bufs if b is not None]
|
||||||
|
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
|
||||||
|
for cpu_b, gpu_b in zip(nb, bufs):
|
||||||
|
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
|
||||||
|
|
||||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
# run on GPU
|
||||||
with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
|
ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||||
import numpy as np
|
|
||||||
assert nb[0] is not None
|
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||||
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
with Context(BEAM=0):
|
||||||
else:
|
ExecutionUnit([ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars)]).update(var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
import numpy as np
|
||||||
|
assert nb[0] is not None
|
||||||
|
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||||
|
else:
|
||||||
|
ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||||
|
else:
|
||||||
|
# Use ExecutionUnit for batched execution
|
||||||
|
if lowered:
|
||||||
|
ExecutionUnit(lowered).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||||
|
|
|
||||||
|
|
@ -125,7 +125,7 @@ pm_post_sched_cache = PatternMatcher([
|
||||||
|
|
||||||
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
|
||||||
# big_sink srcs are all the Tensors
|
# big_sink srcs are all the Tensors
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
||||||
|
|
@ -185,11 +185,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||||
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
|
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
|
||||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
|
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}), buf_uops=buf_uops))
|
||||||
else:
|
else:
|
||||||
# ONE -> ONE
|
# ONE -> ONE
|
||||||
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
|
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars, buf_uops=buf_uops))
|
||||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
with cpu_profile(TracingKey("memory planner")): schedule, buffer_map = memory_planner(schedule)
|
||||||
|
|
||||||
# extract var_vals from BINDs that were stripped (only if there are kernels)
|
# extract var_vals from BINDs that were stripped (only if there are kernels)
|
||||||
var_vals: dict[str, int] = {}
|
var_vals: dict[str, int] = {}
|
||||||
|
|
@ -204,4 +204,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||||
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||||
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||||
return tensor_map, schedule, var_vals
|
return tensor_map, schedule, buffer_map, var_vals
|
||||||
|
|
|
||||||
|
|
@ -237,7 +237,7 @@ class Tensor(OpMixin):
|
||||||
"""
|
"""
|
||||||
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||||
|
|
||||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
|
||||||
"""
|
"""
|
||||||
Creates the schedule needed to realize these Tensor(s), with Variables.
|
Creates the schedule needed to realize these Tensor(s), with Variables.
|
||||||
|
|
||||||
|
|
@ -246,13 +246,13 @@ class Tensor(OpMixin):
|
||||||
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
||||||
|
|
||||||
# this is where the schedule cache should go
|
# this is where the schedule cache should go
|
||||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
|
becomes_map, schedule, buffer_map, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||||
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
|
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
|
||||||
return schedule, var_vals
|
return schedule, buffer_map, var_vals
|
||||||
|
|
||||||
def schedule(self, *lst:Tensor) -> list[ExecItem]:
|
def schedule(self, *lst:Tensor) -> list[ExecItem]:
|
||||||
"""Creates the schedule needed to realize these Tensor(s)."""
|
"""Creates the schedule needed to realize these Tensor(s)."""
|
||||||
schedule, var_vals = self.schedule_with_vars(*lst)
|
schedule, _buffer_map, var_vals = self.schedule_with_vars(*lst)
|
||||||
assert len(var_vals) == 0
|
assert len(var_vals) == 0
|
||||||
return schedule
|
return schedule
|
||||||
|
|
||||||
|
|
@ -260,7 +260,8 @@ class Tensor(OpMixin):
|
||||||
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
||||||
"""Triggers the computation needed to create these Tensor(s)."""
|
"""Triggers the computation needed to create these Tensor(s)."""
|
||||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
|
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
|
||||||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
schedule, buffer_map, var_vals = Tensor.schedule_with_vars(*to_realize)
|
||||||
|
run_schedule(schedule, buffer_map, var_vals, do_update_stats=do_update_stats)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
|
def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue