mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
4 commits
master
...
execution_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43c033e5e5 | ||
|
|
76f2f14233 | ||
|
|
5515de6553 | ||
|
|
439c4319ec |
13 changed files with 257 additions and 51 deletions
|
|
@ -45,9 +45,10 @@ print(f"The schedule contains {len(schedule)} items.")
|
|||
for si in schedule: print(str(si)[:80])
|
||||
|
||||
# *****
|
||||
# 4. Lower and run the schedule.
|
||||
# 4. Run the schedule.
|
||||
|
||||
for si in tqdm(schedule): si.run()
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
run_schedule(schedule)
|
||||
|
||||
# *****
|
||||
# 5. Print the weight change
|
||||
|
|
|
|||
4
test/external/external_benchmark_resnet.py
vendored
4
test/external/external_benchmark_resnet.py
vendored
|
|
@ -71,8 +71,8 @@ class BenchmarkResnetTrain(unittest.TestCase):
|
|||
|
||||
y = x.sequential(layer).contiguous().contiguous_backward()
|
||||
y.sum().backward()
|
||||
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
|
||||
else: sched, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
|
||||
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
|
||||
else: sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
|
||||
|
||||
for _ in range(JITCNT):
|
||||
run_schedule(list(sched))
|
||||
|
|
|
|||
|
|
@ -49,8 +49,8 @@ class BenchmarkBertTrain(unittest.TestCase):
|
|||
|
||||
y = layer(*inputs).contiguous().contiguous_backward()
|
||||
y.sum().backward()
|
||||
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
|
||||
else: sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
|
||||
if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
|
||||
else: sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
|
||||
|
||||
for _ in range(JITCNT):
|
||||
run_schedule(sched)
|
||||
|
|
|
|||
|
|
@ -865,7 +865,7 @@ class TestIdxUpcast(unittest.TestCase):
|
|||
for src in ast.src:
|
||||
if (ret:=self._find_op(src, op)) is not None: return ret
|
||||
def _schedule_render(self, a: Tensor):
|
||||
schedule, _ = a.schedule_with_vars()
|
||||
schedule, _buffer_map, _var_vals = a.schedule_with_vars()
|
||||
for s in schedule:
|
||||
if s.ast.op is Ops.SINK:
|
||||
renderer = Device[s.bufs[0].device].renderer
|
||||
|
|
|
|||
|
|
@ -481,7 +481,7 @@ class TestUOpMethod(unittest.TestCase):
|
|||
a = UOp.variable("a", 1, 10)
|
||||
uop_var = Tensor(a.bind(1))
|
||||
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
|
||||
_, var_vals = (uop_var+st_var).schedule_with_vars()
|
||||
_, _, var_vals = (uop_var+st_var).schedule_with_vars()
|
||||
self.assertEqual(len(var_vals), 1)
|
||||
self.assertEqual(list(var_vals)[0], a.expr)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class TestRingAllReduce(unittest.TestCase):
|
|||
N = 4
|
||||
ds = tuple(f"CPU:{i}" for i in range(N))
|
||||
t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
|
||||
schedules = t.sum(0).schedule_with_vars()[0]
|
||||
schedules, _, _ = t.sum(0).schedule_with_vars()
|
||||
copies = [si for si in schedules if si.ast.op is Ops.COPY]
|
||||
pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
|
||||
# N*(N-1) scatter reduce, and N*(N-1) allgather
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class TestScheduleCache(unittest.TestCase):
|
|||
x = Tensor.ones(10).contiguous().realize()
|
||||
|
||||
t = x + Tensor(v.bind(42))
|
||||
_, var_vals = t.schedule_with_vars()
|
||||
_, _, var_vals = t.schedule_with_vars()
|
||||
self.assertEqual(var_vals, {'pos': 42})
|
||||
|
||||
def test_simple(self):
|
||||
|
|
|
|||
170
tinygrad/engine/execution.py
Normal file
170
tinygrad/engine/execution.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
from __future__ import annotations
|
||||
from typing import Any, cast
|
||||
from tinygrad.helpers import DEBUG, GlobalCounters, all_same, colored, ansilen, PROFILE, ProfilePointEvent, cpu_events, time_to_str, TRACEMETA
|
||||
from tinygrad.uop.ops import UOp, Ops, sym_infer
|
||||
from tinygrad.device import Device, Buffer
|
||||
|
||||
# **************** ExecutionUnit ****************
|
||||
|
||||
class ExecutionUnit:
|
||||
"""
|
||||
A bound, ready-to-execute unit. Replaces CapturedJit.
|
||||
|
||||
Takes ExecItems and binds them to real device resources on execution.
|
||||
"""
|
||||
|
||||
def __init__(self, items: list):
|
||||
"""
|
||||
Create an ExecutionUnit from ExecItems.
|
||||
|
||||
Args:
|
||||
items: ExecItems (with bufs as Buffers or UOps, lib optionally set)
|
||||
"""
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
|
||||
self.items: list[ExecItem] = items
|
||||
self.buffer_map: dict[UOp, Buffer] = {}
|
||||
self.var_vals: dict[str, int] = {}
|
||||
|
||||
# Create bound items with runners - lazy, done on first call
|
||||
self._bound_items: list[tuple[Any, list[Buffer], tuple, dict[str, int]]]|None = None
|
||||
self._graphs: list|None = None
|
||||
self._first_run = True
|
||||
|
||||
def _bind(self):
|
||||
"""Create runners from lib and bind buffers."""
|
||||
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer, ViewOp, EncDec, get_runner, get_program
|
||||
|
||||
self._bound_items = []
|
||||
for item in self.items:
|
||||
# Get buffers - prefer buf_uops with buffer_map, fall back to bufs for backwards compatibility
|
||||
bufs: list[Buffer] = []
|
||||
if item.buf_uops:
|
||||
for uop in item.buf_uops:
|
||||
bufs.append(cast(Buffer, self.buffer_map.get(uop) or uop.buffer))
|
||||
else:
|
||||
for buf in item.bufs:
|
||||
if buf is not None: bufs.append(buf)
|
||||
|
||||
# Create runner from lib or use existing prg
|
||||
if item.prg is not None:
|
||||
runner = item.prg
|
||||
elif item.ast.op is Ops.SINK:
|
||||
device = bufs[0].device
|
||||
if item.lib is not None:
|
||||
# Create runner from cached lib
|
||||
prg = get_program(item.ast, Device[device].renderer)
|
||||
runner = CompiledRunner(prg, item.lib)
|
||||
else:
|
||||
# Compile and create runner
|
||||
runner = get_runner(device, item.ast)
|
||||
elif item.ast.op is Ops.BUFFER_VIEW:
|
||||
runner = ViewOp(bufs[0])
|
||||
elif item.ast.op is Ops.COPY:
|
||||
if hasattr(Device[bufs[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in bufs]):
|
||||
runner = BufferXfer(bufs[0].nbytes, bufs[0].device, bufs[1].device)
|
||||
else:
|
||||
runner = BufferCopy(bufs[0].nbytes, bufs[0].device, bufs[1].device)
|
||||
elif item.ast.op is Ops.ENCDEC:
|
||||
runner = EncDec(item.ast, bufs[0].nbytes, bufs[1].device)
|
||||
else:
|
||||
raise RuntimeError(f"unknown op {item.ast.op}")
|
||||
|
||||
self._bound_items.append((runner, bufs, item.metadata, item.fixedvars))
|
||||
|
||||
def update(self, buffers: dict[UOp, Buffer]|None = None, var_vals: dict[str, int]|None = None) -> ExecutionUnit:
|
||||
"""Update buffer mapping and/or var_vals for next run. Returns self for chaining."""
|
||||
if buffers is not None:
|
||||
self.buffer_map.update(buffers)
|
||||
# Need to rebind if we update buffers
|
||||
self._bound_items = None
|
||||
if var_vals is not None:
|
||||
self.var_vals.update(var_vals)
|
||||
return self
|
||||
|
||||
def __call__(self, var_vals: dict[str, int]|None = None, wait=False, do_update_stats=True, jit=False) -> float|None:
|
||||
"""Execute all items."""
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
|
||||
if var_vals is not None:
|
||||
self.var_vals.update(var_vals)
|
||||
|
||||
# Lazy bind on first call
|
||||
if self._bound_items is None:
|
||||
self._bind()
|
||||
|
||||
assert self._bound_items is not None
|
||||
|
||||
# TODO: create graphs on first run
|
||||
# if self._first_run:
|
||||
# self._create_graphs()
|
||||
# self._first_run = False
|
||||
|
||||
# Execute all items
|
||||
total_et = 0.0
|
||||
for runner, bufs, metadata, fixedvars in self._bound_items:
|
||||
merged_var_vals = self.var_vals | fixedvars
|
||||
|
||||
# Ensure buffers are allocated (skip if jit - already allocated)
|
||||
if not jit:
|
||||
for b in bufs:
|
||||
b.ensure_allocated()
|
||||
|
||||
# Reorder bufs to match program globals if needed
|
||||
if isinstance(runner, CompiledRunner):
|
||||
ordered_bufs = [bufs[i] for i in runner.p.globals]
|
||||
else:
|
||||
ordered_bufs = bufs
|
||||
|
||||
# PROFILE events
|
||||
if PROFILE:
|
||||
payload = {"metadata":metadata, "var_vals":merged_var_vals, "bufs":[b.trace_num for b in ordered_bufs], "name":runner.display_name}
|
||||
payload["outputs"], payload["inputs"] = (runner.p.outs, runner.p.ins) if isinstance(runner, CompiledRunner) else ([0], [1])
|
||||
cpu_events.append(ProfilePointEvent(runner.device, "exec", len(cpu_events), payload))
|
||||
|
||||
et = runner(ordered_bufs, merged_var_vals, wait=wait or DEBUG >= 2)
|
||||
if et is not None:
|
||||
total_et += et
|
||||
|
||||
# Update stats
|
||||
if do_update_stats:
|
||||
GlobalCounters.kernel_count += 1
|
||||
op_est = sym_infer(runner.estimates.ops, merged_var_vals)
|
||||
mem_est = sym_infer(runner.estimates.mem, merged_var_vals)
|
||||
GlobalCounters.global_ops += op_est
|
||||
GlobalCounters.global_mem += mem_est
|
||||
if et is not None:
|
||||
GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 2:
|
||||
lds_est = sym_infer(runner.estimates.lds, merged_var_vals)
|
||||
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores
|
||||
header_color = 'magenta' if jit else ('green' if runner.first_run else None)
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
|
||||
flops_str = f"{flops*1e-9:7.0f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:7.0f} TFLOPS", 'green')
|
||||
mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \
|
||||
colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green')
|
||||
print(f"{colored(f'*** {runner.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
|
||||
f" {runner.display_name+' '*(46-ansilen(runner.display_name))} arg {len(ordered_bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+
|
||||
("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
|
||||
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in metadata] if metadata else ''}")
|
||||
runner.first_run = False
|
||||
|
||||
return total_et if wait else None
|
||||
|
||||
def __add__(self, other: ExecutionUnit) -> ExecutionUnit:
|
||||
"""Combine two ExecutionUnits, rebuild graph lazily."""
|
||||
combined = ExecutionUnit(self.items + other.items)
|
||||
combined.buffer_map = {**self.buffer_map, **other.buffer_map}
|
||||
combined.var_vals = {**self.var_vals, **other.var_vals}
|
||||
return combined
|
||||
|
||||
def free_intermediates(self):
|
||||
"""Deallocate internal buffers."""
|
||||
for buf in self.buffer_map.values():
|
||||
if buf.is_allocated():
|
||||
buf.deallocate()
|
||||
# Reset bound state
|
||||
self._bound_items = None
|
||||
self._graphs = None
|
||||
self._first_run = True
|
||||
|
|
@ -195,6 +195,8 @@ class CapturedJit(Generic[ReturnType]):
|
|||
|
||||
# jit exec
|
||||
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
|
||||
from tinygrad.engine.execution import ExecutionUnit
|
||||
|
||||
# assign inputs
|
||||
for idx, offset, device, size, dtype in self.extra_view_inputs:
|
||||
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
|
||||
|
|
@ -213,7 +215,8 @@ class CapturedJit(Generic[ReturnType]):
|
|||
self._first_run = False
|
||||
|
||||
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
||||
for ei in self._jit_cache: ei.run(var_vals, jit=True)
|
||||
# Use ExecutionUnit for execution
|
||||
ExecutionUnit(self._jit_cache).update(var_vals=var_vals)(jit=True)
|
||||
self._clear_inputs()
|
||||
return self.ret
|
||||
|
||||
|
|
@ -251,7 +254,7 @@ class TinyJit(Generic[ReturnType]):
|
|||
return ret
|
||||
|
||||
def add(self, ei:ExecItem):
|
||||
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg))
|
||||
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.lib, ei.prg))
|
||||
|
||||
def reset(self):
|
||||
assert self.fxn is not None, "can't reset without function"
|
||||
|
|
@ -308,14 +311,16 @@ class TinyJit(Generic[ReturnType]):
|
|||
|
||||
# prune independent kernels (optional)
|
||||
if self.prune:
|
||||
from tinygrad.engine.execution import ExecutionUnit
|
||||
depends = set(input_buffers)
|
||||
update_depends(depends, jit_cache)
|
||||
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
|
||||
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
|
||||
# run the onetime kernels here
|
||||
for ei in onetime:
|
||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||
ei.run(var_vals, jit=True)
|
||||
if onetime:
|
||||
for ei in onetime:
|
||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||
ExecutionUnit(onetime).update(var_vals=var_vals)(jit=True)
|
||||
jit_cache = pruned
|
||||
|
||||
# memory planning (optional)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
|||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.runtime.support.memory import TLSFAllocator
|
||||
|
||||
|
|
@ -63,8 +63,16 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
|
|||
|
||||
return assigned
|
||||
|
||||
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
|
||||
def memory_planner(schedule:list[ExecItem]) -> tuple[list[ExecItem], dict[UOp, Buffer]]:
|
||||
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
|
||||
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
|
||||
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
|
||||
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
|
||||
new_schedule = [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars, buf_uops=si.buf_uops)
|
||||
for si in schedule]
|
||||
# Build buffer_map from buf_uops -> assigned buffers
|
||||
buffer_map: dict[UOp, Buffer] = {}
|
||||
for si in new_schedule:
|
||||
for i, uop in enumerate(si.buf_uops):
|
||||
if uop not in buffer_map and i < len(si.bufs) and si.bufs[i] is not None:
|
||||
buffer_map[uop] = si.bufs[i] # type: ignore
|
||||
return new_schedule, buffer_map
|
||||
|
|
|
|||
|
|
@ -183,15 +183,21 @@ si_lowerer = PatternMatcher([
|
|||
@dataclass
|
||||
class ExecItem:
|
||||
ast: UOp
|
||||
bufs: list[Buffer|None] = field(default_factory=list)
|
||||
bufs: list[Buffer|None] = field(default_factory=list) # TODO: deprecate, use buf_uops + buffer_map
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
fixedvars: dict[str, int] = field(default_factory=dict)
|
||||
lib: bytes|None = None # compiled binary, None for COPY/VIEW/ENCDEC
|
||||
prg: Runner|None = None
|
||||
buf_uops: tuple[UOp, ...] = () # buffer UOps, binding happens in ExecutionUnit
|
||||
|
||||
def lower(self):
|
||||
"""Populate self.prg by lowering the AST."""
|
||||
"""Populate self.prg and self.lib by lowering the AST."""
|
||||
if self.prg is not None: return self
|
||||
try: self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
|
||||
try:
|
||||
self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
|
||||
# Store lib for SINK ops (compiled kernels)
|
||||
if isinstance(self.prg, CompiledRunner):
|
||||
self.lib = self.prg.lib
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"error lowering {self.ast.op}")
|
||||
|
|
@ -237,24 +243,39 @@ class ExecItem:
|
|||
|
||||
capturing: list = [] # put classes with an add method in here
|
||||
|
||||
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
while len(schedule):
|
||||
ei = schedule.pop(0).lower()
|
||||
def run_schedule(schedule:list[ExecItem], buffer_map:dict[UOp, Buffer]|None=None, var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
from tinygrad.engine.execution import ExecutionUnit
|
||||
if buffer_map is None: buffer_map = {}
|
||||
|
||||
# Lower all items first
|
||||
lowered: list[ExecItem] = []
|
||||
for ei in schedule:
|
||||
ei = ei.lower()
|
||||
if len(capturing) and CAPTURING: capturing[0].add(ei)
|
||||
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
|
||||
# copy in allocated buffers from the GPU
|
||||
bufs = [b for b in ei.bufs if b is not None]
|
||||
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
|
||||
for cpu_b, gpu_b in zip(nb, bufs):
|
||||
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
|
||||
lowered.append(ei)
|
||||
|
||||
# run on GPU
|
||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
if VALIDATE_WITH_CPU:
|
||||
# Run item by item with CPU validation
|
||||
for ei in lowered:
|
||||
if ei.ast.op is Ops.SINK:
|
||||
# copy in allocated buffers from the GPU
|
||||
bufs = [b for b in ei.bufs if b is not None]
|
||||
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
|
||||
for cpu_b, gpu_b in zip(nb, bufs):
|
||||
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
|
||||
import numpy as np
|
||||
assert nb[0] is not None
|
||||
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
# run on GPU
|
||||
ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
with Context(BEAM=0):
|
||||
ExecutionUnit([ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars)]).update(var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
import numpy as np
|
||||
assert nb[0] is not None
|
||||
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
else:
|
||||
# Use ExecutionUnit for batched execution
|
||||
if lowered:
|
||||
ExecutionUnit(lowered).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ pm_post_sched_cache = PatternMatcher([
|
|||
|
||||
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
|
||||
|
|
@ -185,11 +185,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
|
||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
|
||||
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}), buf_uops=buf_uops))
|
||||
else:
|
||||
# ONE -> ONE
|
||||
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
|
||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars, buf_uops=buf_uops))
|
||||
with cpu_profile(TracingKey("memory planner")): schedule, buffer_map = memory_planner(schedule)
|
||||
|
||||
# extract var_vals from BINDs that were stripped (only if there are kernels)
|
||||
var_vals: dict[str, int] = {}
|
||||
|
|
@ -204,4 +204,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
return tensor_map, schedule, var_vals
|
||||
return tensor_map, schedule, buffer_map, var_vals
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ class Tensor(OpMixin):
|
|||
"""
|
||||
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
|
||||
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
|
||||
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
|
||||
"""
|
||||
Creates the schedule needed to realize these Tensor(s), with Variables.
|
||||
|
||||
|
|
@ -246,13 +246,13 @@ class Tensor(OpMixin):
|
|||
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
||||
|
||||
# this is where the schedule cache should go
|
||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||
becomes_map, schedule, buffer_map, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
|
||||
return schedule, var_vals
|
||||
return schedule, buffer_map, var_vals
|
||||
|
||||
def schedule(self, *lst:Tensor) -> list[ExecItem]:
|
||||
"""Creates the schedule needed to realize these Tensor(s)."""
|
||||
schedule, var_vals = self.schedule_with_vars(*lst)
|
||||
schedule, _buffer_map, var_vals = self.schedule_with_vars(*lst)
|
||||
assert len(var_vals) == 0
|
||||
return schedule
|
||||
|
||||
|
|
@ -260,7 +260,8 @@ class Tensor(OpMixin):
|
|||
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
||||
"""Triggers the computation needed to create these Tensor(s)."""
|
||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
|
||||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
||||
schedule, buffer_map, var_vals = Tensor.schedule_with_vars(*to_realize)
|
||||
run_schedule(schedule, buffer_map, var_vals, do_update_stats=do_update_stats)
|
||||
return self
|
||||
|
||||
def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue