schedule cache cleanup

This commit is contained in:
George Hotz 2025-12-02 16:59:32 -08:00
commit 79f2cfcb96
4 changed files with 18 additions and 21 deletions

View file

@ -1,5 +1,6 @@
import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.schedule import schedule_cache
from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general
@ -68,6 +69,7 @@ if __name__ == "__main__":
t()
# these caches will keep uops alive
schedule_cache.clear()
method_cache.clear()
apply_movement_op.cache_clear()
_apply_reshape.cache_clear()

View file

@ -143,7 +143,7 @@ pm_post_sched_cache = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
])
schedule_cache = {}
schedule_cache: dict[bytes, tuple[UOp, dict[UOp, UOp]]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}", True)
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
@ -158,11 +158,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
big_sink = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
sched_cache_key = big_sink.key
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is not None:
# schedule cache hit
print("SC HIT", sched_cache_key)
big_sink, tensor_map = sc_ret
else:
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# tensor map is what we return
tensor_map: dict[UOp, UOp] = {}
@ -175,20 +171,17 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
big_sink = big_sink.substitute(tensor_map, name="Apply Kernelize Map")
# save in schedule cache
print("sc miss", sched_cache_key)
schedule_cache[sched_cache_key] = (big_sink, tensor_map)
#assert len([x for x in big_sink.toposort() if x.op is Ops.UNIQUE]) == 0
else:
# schedule cache hit
big_sink, tensor_map = sc_ret
# replace all the LUNIQUEs with UNIQUEs
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
big_sink = graph_rewrite(big_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for sched cache")
new_tensor_map = {}
for k,v in tensor_map.items():
k = graph_rewrite(k, pm_post_sched_cache, ctx=input_buffers_reverse)
v = graph_rewrite(v, pm_post_sched_cache, ctx=input_buffers_reverse)
new_tensor_map[k] = v
tensor_map = new_tensor_map
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
tm_src = graph_rewrite(tensor_map_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for tensor map").src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
# create the schedule
schedule, var_vals = create_schedule_with_vars(big_sink)
@ -197,6 +190,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
# remove all AFTERs, after scheduling, the tensors are just buffers
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)")
if (DEBUG >= (1 if sc_ret is None else 2) and len(schedule) > 1) or DEBUG >= 3:
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache")
return tensor_map, schedule, var_vals

View file

@ -2,10 +2,10 @@ from dataclasses import dataclass, field
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op

View file

@ -662,7 +662,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
return ret
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
assert self.src[0].op is Ops.UNIQUE, f"buffer src[0] must be UNIQUE"
assert self.src[0].op is Ops.UNIQUE, "buffer src[0] must be UNIQUE"
if (cret:=buffers.get(self)) is not None: return cret
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.size, rdtype).ref(1)