mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
schedule cache cleanup
This commit is contained in:
parent
add768aab0
commit
79f2cfcb96
4 changed files with 18 additions and 21 deletions
2
test/external/external_uop_gc.py
vendored
2
test/external/external_uop_gc.py
vendored
|
|
@ -1,5 +1,6 @@
|
|||
import gc
|
||||
from tinygrad import Tensor, UOp, Device, nn
|
||||
from tinygrad.engine.schedule import schedule_cache
|
||||
from tinygrad.engine.realize import method_cache, get_program
|
||||
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
|
||||
from tinygrad.uop.divandmod import fold_divmod_general
|
||||
|
|
@ -68,6 +69,7 @@ if __name__ == "__main__":
|
|||
t()
|
||||
|
||||
# these caches will keep uops alive
|
||||
schedule_cache.clear()
|
||||
method_cache.clear()
|
||||
apply_movement_op.cache_clear()
|
||||
_apply_reshape.cache_clear()
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ pm_post_sched_cache = PatternMatcher([
|
|||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
|
||||
])
|
||||
|
||||
schedule_cache = {}
|
||||
schedule_cache: dict[bytes, tuple[UOp, dict[UOp, UOp]]] = {}
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}", True)
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
|
||||
|
|
@ -158,11 +158,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
big_sink = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
|
||||
sched_cache_key = big_sink.key
|
||||
|
||||
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is not None:
|
||||
# schedule cache hit
|
||||
print("SC HIT", sched_cache_key)
|
||||
big_sink, tensor_map = sc_ret
|
||||
else:
|
||||
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
# tensor map is what we return
|
||||
tensor_map: dict[UOp, UOp] = {}
|
||||
|
||||
|
|
@ -175,20 +171,17 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
big_sink = big_sink.substitute(tensor_map, name="Apply Kernelize Map")
|
||||
|
||||
# save in schedule cache
|
||||
print("sc miss", sched_cache_key)
|
||||
schedule_cache[sched_cache_key] = (big_sink, tensor_map)
|
||||
|
||||
#assert len([x for x in big_sink.toposort() if x.op is Ops.UNIQUE]) == 0
|
||||
else:
|
||||
# schedule cache hit
|
||||
big_sink, tensor_map = sc_ret
|
||||
|
||||
# replace all the LUNIQUEs with UNIQUEs
|
||||
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
|
||||
big_sink = graph_rewrite(big_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for sched cache")
|
||||
new_tensor_map = {}
|
||||
for k,v in tensor_map.items():
|
||||
k = graph_rewrite(k, pm_post_sched_cache, ctx=input_buffers_reverse)
|
||||
v = graph_rewrite(v, pm_post_sched_cache, ctx=input_buffers_reverse)
|
||||
new_tensor_map[k] = v
|
||||
tensor_map = new_tensor_map
|
||||
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
|
||||
tm_src = graph_rewrite(tensor_map_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for tensor map").src
|
||||
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
|
||||
|
||||
# create the schedule
|
||||
schedule, var_vals = create_schedule_with_vars(big_sink)
|
||||
|
|
@ -197,6 +190,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
|||
# remove all AFTERs, after scheduling, the tensors are just buffers
|
||||
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
|
||||
|
||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
||||
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)")
|
||||
if (DEBUG >= (1 if sc_ret is None else 2) and len(schedule) > 1) or DEBUG >= 3:
|
||||
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
return tensor_map, schedule, var_vals
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ from dataclasses import dataclass, field
|
|||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
|
|
|
|||
|
|
@ -662,7 +662,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
|
||||
return ret
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
assert self.src[0].op is Ops.UNIQUE, f"buffer src[0] must be UNIQUE"
|
||||
assert self.src[0].op is Ops.UNIQUE, "buffer src[0] must be UNIQUE"
|
||||
if (cret:=buffers.get(self)) is not None: return cret
|
||||
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
|
||||
if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.size, rdtype).ref(1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue