Compare commits

...

4 commits

Author SHA1 Message Date
George Hotz
43c033e5e5 clean 2025-12-21 21:39:17 -04:00
George Hotz
76f2f14233 more 2025-12-21 15:48:59 -04:00
George Hotz
5515de6553 chain 2025-12-21 12:52:20 -04:00
George Hotz
439c4319ec phase 1 2025-12-21 12:34:00 -04:00
13 changed files with 257 additions and 51 deletions

View file

@ -45,9 +45,10 @@ print(f"The schedule contains {len(schedule)} items.")
for si in schedule: print(str(si)[:80]) for si in schedule: print(str(si)[:80])
# ***** # *****
# 4. Lower and run the schedule. # 4. Run the schedule.
for si in tqdm(schedule): si.run() from tinygrad.engine.realize import run_schedule
run_schedule(schedule)
# ***** # *****
# 5. Print the weight change # 5. Print the weight change

View file

@ -71,8 +71,8 @@ class BenchmarkResnetTrain(unittest.TestCase):
y = x.sequential(layer).contiguous().contiguous_backward() y = x.sequential(layer).contiguous().contiguous_backward()
y.sum().backward() y.sum().backward()
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step()) if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
else: sched, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params]) else: sched, _, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
for _ in range(JITCNT): for _ in range(JITCNT):
run_schedule(list(sched)) run_schedule(list(sched))

View file

@ -49,8 +49,8 @@ class BenchmarkBertTrain(unittest.TestCase):
y = layer(*inputs).contiguous().contiguous_backward() y = layer(*inputs).contiguous().contiguous_backward()
y.sum().backward() y.sum().backward()
if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step()) if getenv("ASSIGN", 1): sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *optim.schedule_step())
else: sched, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params]) else: sched, _, _ = Tensor.schedule_with_vars(y, *list(inputs), *[t.grad for t in optim.params])
for _ in range(JITCNT): for _ in range(JITCNT):
run_schedule(sched) run_schedule(sched)

View file

@ -865,7 +865,7 @@ class TestIdxUpcast(unittest.TestCase):
for src in ast.src: for src in ast.src:
if (ret:=self._find_op(src, op)) is not None: return ret if (ret:=self._find_op(src, op)) is not None: return ret
def _schedule_render(self, a: Tensor): def _schedule_render(self, a: Tensor):
schedule, _ = a.schedule_with_vars() schedule, _buffer_map, _var_vals = a.schedule_with_vars()
for s in schedule: for s in schedule:
if s.ast.op is Ops.SINK: if s.ast.op is Ops.SINK:
renderer = Device[s.bufs[0].device].renderer renderer = Device[s.bufs[0].device].renderer

View file

@ -481,7 +481,7 @@ class TestUOpMethod(unittest.TestCase):
a = UOp.variable("a", 1, 10) a = UOp.variable("a", 1, 10)
uop_var = Tensor(a.bind(1)) uop_var = Tensor(a.bind(1))
st_var = Tensor.empty((2, 10))[:, :a.bind(1)] st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
_, var_vals = (uop_var+st_var).schedule_with_vars() _, _, var_vals = (uop_var+st_var).schedule_with_vars()
self.assertEqual(len(var_vals), 1) self.assertEqual(len(var_vals), 1)
self.assertEqual(list(var_vals)[0], a.expr) self.assertEqual(list(var_vals)[0], a.expr)

View file

@ -9,7 +9,7 @@ class TestRingAllReduce(unittest.TestCase):
N = 4 N = 4
ds = tuple(f"CPU:{i}" for i in range(N)) ds = tuple(f"CPU:{i}" for i in range(N))
t = Tensor.empty(N, N*100).shard(ds, axis=0).realize() t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
schedules = t.sum(0).schedule_with_vars()[0] schedules, _, _ = t.sum(0).schedule_with_vars()
copies = [si for si in schedules if si.ast.op is Ops.COPY] copies = [si for si in schedules if si.ast.op is Ops.COPY]
pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies] pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
# N*(N-1) scatter reduce, and N*(N-1) allgather # N*(N-1) scatter reduce, and N*(N-1) allgather

View file

@ -23,7 +23,7 @@ class TestScheduleCache(unittest.TestCase):
x = Tensor.ones(10).contiguous().realize() x = Tensor.ones(10).contiguous().realize()
t = x + Tensor(v.bind(42)) t = x + Tensor(v.bind(42))
_, var_vals = t.schedule_with_vars() _, _, var_vals = t.schedule_with_vars()
self.assertEqual(var_vals, {'pos': 42}) self.assertEqual(var_vals, {'pos': 42})
def test_simple(self): def test_simple(self):

View file

@ -0,0 +1,170 @@
from __future__ import annotations
from typing import Any, cast
from tinygrad.helpers import DEBUG, GlobalCounters, all_same, colored, ansilen, PROFILE, ProfilePointEvent, cpu_events, time_to_str, TRACEMETA
from tinygrad.uop.ops import UOp, Ops, sym_infer
from tinygrad.device import Device, Buffer
# **************** ExecutionUnit ****************
class ExecutionUnit:
"""
A bound, ready-to-execute unit. Replaces CapturedJit.
Takes ExecItems and binds them to real device resources on execution.
"""
def __init__(self, items: list):
"""
Create an ExecutionUnit from ExecItems.
Args:
items: ExecItems (with bufs as Buffers or UOps, lib optionally set)
"""
from tinygrad.engine.realize import ExecItem
self.items: list[ExecItem] = items
self.buffer_map: dict[UOp, Buffer] = {}
self.var_vals: dict[str, int] = {}
# Create bound items with runners - lazy, done on first call
self._bound_items: list[tuple[Any, list[Buffer], tuple, dict[str, int]]]|None = None
self._graphs: list|None = None
self._first_run = True
def _bind(self):
"""Create runners from lib and bind buffers."""
from tinygrad.engine.realize import CompiledRunner, BufferCopy, BufferXfer, ViewOp, EncDec, get_runner, get_program
self._bound_items = []
for item in self.items:
# Get buffers - prefer buf_uops with buffer_map, fall back to bufs for backwards compatibility
bufs: list[Buffer] = []
if item.buf_uops:
for uop in item.buf_uops:
bufs.append(cast(Buffer, self.buffer_map.get(uop) or uop.buffer))
else:
for buf in item.bufs:
if buf is not None: bufs.append(buf)
# Create runner from lib or use existing prg
if item.prg is not None:
runner = item.prg
elif item.ast.op is Ops.SINK:
device = bufs[0].device
if item.lib is not None:
# Create runner from cached lib
prg = get_program(item.ast, Device[device].renderer)
runner = CompiledRunner(prg, item.lib)
else:
# Compile and create runner
runner = get_runner(device, item.ast)
elif item.ast.op is Ops.BUFFER_VIEW:
runner = ViewOp(bufs[0])
elif item.ast.op is Ops.COPY:
if hasattr(Device[bufs[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in bufs]):
runner = BufferXfer(bufs[0].nbytes, bufs[0].device, bufs[1].device)
else:
runner = BufferCopy(bufs[0].nbytes, bufs[0].device, bufs[1].device)
elif item.ast.op is Ops.ENCDEC:
runner = EncDec(item.ast, bufs[0].nbytes, bufs[1].device)
else:
raise RuntimeError(f"unknown op {item.ast.op}")
self._bound_items.append((runner, bufs, item.metadata, item.fixedvars))
def update(self, buffers: dict[UOp, Buffer]|None = None, var_vals: dict[str, int]|None = None) -> ExecutionUnit:
"""Update buffer mapping and/or var_vals for next run. Returns self for chaining."""
if buffers is not None:
self.buffer_map.update(buffers)
# Need to rebind if we update buffers
self._bound_items = None
if var_vals is not None:
self.var_vals.update(var_vals)
return self
def __call__(self, var_vals: dict[str, int]|None = None, wait=False, do_update_stats=True, jit=False) -> float|None:
"""Execute all items."""
from tinygrad.engine.realize import CompiledRunner
if var_vals is not None:
self.var_vals.update(var_vals)
# Lazy bind on first call
if self._bound_items is None:
self._bind()
assert self._bound_items is not None
# TODO: create graphs on first run
# if self._first_run:
# self._create_graphs()
# self._first_run = False
# Execute all items
total_et = 0.0
for runner, bufs, metadata, fixedvars in self._bound_items:
merged_var_vals = self.var_vals | fixedvars
# Ensure buffers are allocated (skip if jit - already allocated)
if not jit:
for b in bufs:
b.ensure_allocated()
# Reorder bufs to match program globals if needed
if isinstance(runner, CompiledRunner):
ordered_bufs = [bufs[i] for i in runner.p.globals]
else:
ordered_bufs = bufs
# PROFILE events
if PROFILE:
payload = {"metadata":metadata, "var_vals":merged_var_vals, "bufs":[b.trace_num for b in ordered_bufs], "name":runner.display_name}
payload["outputs"], payload["inputs"] = (runner.p.outs, runner.p.ins) if isinstance(runner, CompiledRunner) else ([0], [1])
cpu_events.append(ProfilePointEvent(runner.device, "exec", len(cpu_events), payload))
et = runner(ordered_bufs, merged_var_vals, wait=wait or DEBUG >= 2)
if et is not None:
total_et += et
# Update stats
if do_update_stats:
GlobalCounters.kernel_count += 1
op_est = sym_infer(runner.estimates.ops, merged_var_vals)
mem_est = sym_infer(runner.estimates.mem, merged_var_vals)
GlobalCounters.global_ops += op_est
GlobalCounters.global_mem += mem_est
if et is not None:
GlobalCounters.time_sum_s += et
if DEBUG >= 2:
lds_est = sym_infer(runner.estimates.lds, merged_var_vals)
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores
header_color = 'magenta' if jit else ('green' if runner.first_run else None)
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
flops_str = f"{flops*1e-9:7.0f} GFLOPS" if flops < 1e14 else colored(f"{flops*1e-12:7.0f} TFLOPS", 'green')
mem_str = f"{membw*1e-9:4.0f}|{ldsbw*1e-9:<6.0f} GB/s" if membw < 1e13 and ldsbw < 1e15 else \
colored(f"{membw*1e-12:4.0f}|{ldsbw*1e-12:<6.0f} TB/s", 'green')
print(f"{colored(f'*** {runner.device[:7]:7s} {GlobalCounters.kernel_count:4d}', header_color)}"+
f" {runner.display_name+' '*(46-ansilen(runner.display_name))} arg {len(ordered_bufs):2d} mem {GlobalCounters.mem_used/1e9:6.2f} GB"+
("" if et is None else f" tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({flops_str} {mem_str})")+
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in metadata] if metadata else ''}")
runner.first_run = False
return total_et if wait else None
def __add__(self, other: ExecutionUnit) -> ExecutionUnit:
"""Combine two ExecutionUnits, rebuild graph lazily."""
combined = ExecutionUnit(self.items + other.items)
combined.buffer_map = {**self.buffer_map, **other.buffer_map}
combined.var_vals = {**self.var_vals, **other.var_vals}
return combined
def free_intermediates(self):
"""Deallocate internal buffers."""
for buf in self.buffer_map.values():
if buf.is_allocated():
buf.deallocate()
# Reset bound state
self._bound_items = None
self._graphs = None
self._first_run = True

View file

@ -195,6 +195,8 @@ class CapturedJit(Generic[ReturnType]):
# jit exec # jit exec
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType: def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
from tinygrad.engine.execution import ExecutionUnit
# assign inputs # assign inputs
for idx, offset, device, size, dtype in self.extra_view_inputs: for idx, offset, device, size, dtype in self.extra_view_inputs:
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated()) input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
@ -213,7 +215,8 @@ class CapturedJit(Generic[ReturnType]):
self._first_run = False self._first_run = False
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels") if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
for ei in self._jit_cache: ei.run(var_vals, jit=True) # Use ExecutionUnit for execution
ExecutionUnit(self._jit_cache).update(var_vals=var_vals)(jit=True)
self._clear_inputs() self._clear_inputs()
return self.ret return self.ret
@ -251,7 +254,7 @@ class TinyJit(Generic[ReturnType]):
return ret return ret
def add(self, ei:ExecItem): def add(self, ei:ExecItem):
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg)) self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.lib, ei.prg))
def reset(self): def reset(self):
assert self.fxn is not None, "can't reset without function" assert self.fxn is not None, "can't reset without function"
@ -308,14 +311,16 @@ class TinyJit(Generic[ReturnType]):
# prune independent kernels (optional) # prune independent kernels (optional)
if self.prune: if self.prune:
from tinygrad.engine.execution import ExecutionUnit
depends = set(input_buffers) depends = set(input_buffers)
update_depends(depends, jit_cache) update_depends(depends, jit_cache)
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei))) pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels") if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
# run the onetime kernels here # run the onetime kernels here
if onetime:
for ei in onetime: for ei in onetime:
for b in ei.bufs: cast(Buffer, b).ensure_allocated() for b in ei.bufs: cast(Buffer, b).ensure_allocated()
ei.run(var_vals, jit=True) ExecutionUnit(onetime).update(var_vals=var_vals)(jit=True)
jit_cache = pruned jit_cache = pruned
# memory planning (optional) # memory planning (optional)

View file

@ -3,7 +3,7 @@ from collections import defaultdict
from tinygrad.engine.realize import ExecItem from tinygrad.engine.realize import ExecItem
from tinygrad.device import Device, Buffer from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
from tinygrad.uop.ops import Ops from tinygrad.uop.ops import Ops, UOp
from tinygrad.dtype import dtypes, ImageDType from tinygrad.dtype import dtypes, ImageDType
from tinygrad.runtime.support.memory import TLSFAllocator from tinygrad.runtime.support.memory import TLSFAllocator
@ -63,8 +63,16 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
return assigned return assigned
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]: def memory_planner(schedule:list[ExecItem]) -> tuple[list[ExecItem], dict[UOp, Buffer]]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule], assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None}) noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule] new_schedule = [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars, buf_uops=si.buf_uops)
for si in schedule]
# Build buffer_map from buf_uops -> assigned buffers
buffer_map: dict[UOp, Buffer] = {}
for si in new_schedule:
for i, uop in enumerate(si.buf_uops):
if uop not in buffer_map and i < len(si.bufs) and si.bufs[i] is not None:
buffer_map[uop] = si.bufs[i] # type: ignore
return new_schedule, buffer_map

View file

@ -183,15 +183,21 @@ si_lowerer = PatternMatcher([
@dataclass @dataclass
class ExecItem: class ExecItem:
ast: UOp ast: UOp
bufs: list[Buffer|None] = field(default_factory=list) bufs: list[Buffer|None] = field(default_factory=list) # TODO: deprecate, use buf_uops + buffer_map
metadata: tuple[Metadata, ...] = () metadata: tuple[Metadata, ...] = ()
fixedvars: dict[str, int] = field(default_factory=dict) fixedvars: dict[str, int] = field(default_factory=dict)
lib: bytes|None = None # compiled binary, None for COPY/VIEW/ENCDEC
prg: Runner|None = None prg: Runner|None = None
buf_uops: tuple[UOp, ...] = () # buffer UOps, binding happens in ExecutionUnit
def lower(self): def lower(self):
"""Populate self.prg by lowering the AST.""" """Populate self.prg and self.lib by lowering the AST."""
if self.prg is not None: return self if self.prg is not None: return self
try: self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs)) try:
self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
# Store lib for SINK ops (compiled kernels)
if isinstance(self.prg, CompiledRunner):
self.lib = self.prg.lib
except Exception as e: except Exception as e:
if DEBUG >= 2: if DEBUG >= 2:
print(f"error lowering {self.ast.op}") print(f"error lowering {self.ast.op}")
@ -237,11 +243,21 @@ class ExecItem:
capturing: list = [] # put classes with an add method in here capturing: list = [] # put classes with an add method in here
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True): def run_schedule(schedule:list[ExecItem], buffer_map:dict[UOp, Buffer]|None=None, var_vals:dict[str, int]|None=None, do_update_stats=True):
while len(schedule): from tinygrad.engine.execution import ExecutionUnit
ei = schedule.pop(0).lower() if buffer_map is None: buffer_map = {}
# Lower all items first
lowered: list[ExecItem] = []
for ei in schedule:
ei = ei.lower()
if len(capturing) and CAPTURING: capturing[0].add(ei) if len(capturing) and CAPTURING: capturing[0].add(ei)
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK: lowered.append(ei)
if VALIDATE_WITH_CPU:
# Run item by item with CPU validation
for ei in lowered:
if ei.ast.op is Ops.SINK:
# copy in allocated buffers from the GPU # copy in allocated buffers from the GPU
bufs = [b for b in ei.bufs if b is not None] bufs = [b for b in ei.bufs if b is not None]
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs] nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
@ -249,12 +265,17 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer()) if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
# run on GPU # run on GPU
ei.run(var_vals, do_update_stats=do_update_stats) ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
# validate the output buffers match (NOTE: this is assuming the output is buffer 0) # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats) with Context(BEAM=0):
ExecutionUnit([ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars)]).update(var_vals=var_vals)(do_update_stats=do_update_stats)
import numpy as np import numpy as np
assert nb[0] is not None assert nb[0] is not None
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3) np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
else: else:
ei.run(var_vals, do_update_stats=do_update_stats) ExecutionUnit([ei]).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)
else:
# Use ExecutionUnit for batched execution
if lowered:
ExecutionUnit(lowered).update(buffers=buffer_map, var_vals=var_vals)(do_update_stats=do_update_stats)

View file

@ -125,7 +125,7 @@ pm_post_sched_cache = PatternMatcher([
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {} schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}") @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]: def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
# big_sink srcs are all the Tensors # big_sink srcs are all the Tensors
st = time.perf_counter() st = time.perf_counter()
@ -185,11 +185,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num'] dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}))) schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}), buf_uops=buf_uops))
else: else:
# ONE -> ONE # ONE -> ONE
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars)) schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars, buf_uops=buf_uops))
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule) with cpu_profile(TracingKey("memory planner")): schedule, buffer_map = memory_planner(schedule)
# extract var_vals from BINDs that were stripped (only if there are kernels) # extract var_vals from BINDs that were stripped (only if there are kernels)
var_vals: dict[str, int] = {} var_vals: dict[str, int] = {}
@ -204,4 +204,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\ print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\ f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache") f" | {len(UOpMetaClass.ucache)} uops in cache")
return tensor_map, schedule, var_vals return tensor_map, schedule, buffer_map, var_vals

View file

@ -237,7 +237,7 @@ class Tensor(OpMixin):
""" """
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)] return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]: def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[UOp, Buffer], dict[str, int]]:
""" """
Creates the schedule needed to realize these Tensor(s), with Variables. Creates the schedule needed to realize these Tensor(s), with Variables.
@ -246,13 +246,13 @@ class Tensor(OpMixin):
big_sink = UOp.sink(*[x.uop for x in (self,)+lst]) big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
# this is where the schedule cache should go # this is where the schedule cache should go
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink) becomes_map, schedule, buffer_map, var_vals = complete_create_schedule_with_vars(big_sink)
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map") _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
return schedule, var_vals return schedule, buffer_map, var_vals
def schedule(self, *lst:Tensor) -> list[ExecItem]: def schedule(self, *lst:Tensor) -> list[ExecItem]:
"""Creates the schedule needed to realize these Tensor(s).""" """Creates the schedule needed to realize these Tensor(s)."""
schedule, var_vals = self.schedule_with_vars(*lst) schedule, _buffer_map, var_vals = self.schedule_with_vars(*lst)
assert len(var_vals) == 0 assert len(var_vals) == 0
return schedule return schedule
@ -260,7 +260,8 @@ class Tensor(OpMixin):
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor: def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s).""" """Triggers the computation needed to create these Tensor(s)."""
if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]): if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats) schedule, buffer_map, var_vals = Tensor.schedule_with_vars(*to_realize)
run_schedule(schedule, buffer_map, var_vals, do_update_stats=do_update_stats)
return self return self
def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor: def replace(self, x:Tensor, allow_shape_mismatch=False) -> Tensor: