mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
c8f47c1d07
commit
16afe04f45
3 changed files with 24 additions and 19 deletions
11
test/external/process_replay/process_replay.py
vendored
11
test/external/process_replay/process_replay.py
vendored
|
|
@ -2,11 +2,11 @@
|
|||
# compare kernels created by HEAD against master
|
||||
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
|
||||
from typing import Callable, cast
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, dedup
|
||||
from tinygrad.engine.grouper import get_becomes_map
|
||||
from tinygrad.codegen.kernel import Kernel, Opt
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.ops import UOp, Ops
|
||||
|
||||
# *** process replay settings
|
||||
|
||||
|
|
@ -34,8 +34,9 @@ class ProcessReplayWarning(Warning): pass
|
|||
# *** recreators
|
||||
|
||||
def recreate_sched(big_sink:UOp) -> list[UOp]:
|
||||
sched, _, __ = create_schedule_with_vars(big_sink)
|
||||
return [x.ast for x in sched]
|
||||
sched_sink = get_becomes_map(big_sink)[0][big_sink]
|
||||
return dedup(u.src[1].arg.ast for u in sched_sink.toposort if u.op is Ops.ASSIGN)
|
||||
|
||||
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str, _) -> str:
|
||||
k = Kernel(ast, opts=opts)
|
||||
for opt in applied_opts: k.apply_opt(opt)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, grap
|
|||
from tinygrad.ops import can_pad, sint, track_rewrites
|
||||
from tinygrad.codegen.lowerer import get_contraction
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, flatten, getenv, pluralize, ContextVar, Context, diskcache_put
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP, CAPTURE_PROCESS_REPLAY
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
|
|
@ -439,6 +439,13 @@ def get_name(ret:tuple[dict[UOp, UOp], dict[Variable, int]]) -> str:
|
|||
kcount = len({u.src[1] for u in ret[0].values() if u.op is Ops.ASSIGN})
|
||||
return f"Schedule {pluralize('Kernel', kcount)}"+(f" (with_{pluralize('Var', len(ret[1]))})" if ret[1] else "")
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
import atexit
|
||||
@atexit.register
|
||||
def save_process_replay():
|
||||
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
||||
|
||||
@track_rewrites(name_fxn=get_name)
|
||||
def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
|
||||
# merge_views + simplify
|
||||
|
|
@ -487,4 +494,12 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
|
|||
var_vals: dict[Variable, int] = {}
|
||||
sched_sink = graph_rewrite(sched_sink, create_ast, ctx=var_vals, bottom_up=True)
|
||||
becomes_map[big_sink] = sched_sink
|
||||
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
with Context(PICKLE_BUFFERS=0):
|
||||
import pickle
|
||||
asts = dedup(u.arg.ast for u in sched_sink.toposort if u.op is Ops.KERNEL)
|
||||
PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, asts))
|
||||
|
||||
return becomes_map, var_vals
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
import atexit, pickle
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
from tinygrad.ops import UOp, Variable, Ops, buffers
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.helpers import Metadata, CAPTURE_PROCESS_REPLAY, DEBUG, Context, ContextVar, diskcache_put, unwrap
|
||||
from tinygrad.helpers import Metadata, DEBUG, unwrap
|
||||
from tinygrad.engine.grouper import get_becomes_map
|
||||
|
||||
# **** ScheduleItem return type
|
||||
|
|
@ -14,12 +13,6 @@ class ScheduleItem:
|
|||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
|
||||
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
@atexit.register
|
||||
def save_process_replay():
|
||||
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
|
||||
|
|
@ -53,10 +46,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
|||
if len(schedule) != len(in_degree): raise RuntimeError(f"created {len(in_degree)} kernels but only scheduled {len(schedule)}")
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
||||
|
||||
# capture process replay
|
||||
if CAPTURE_PROCESS_REPLAY:
|
||||
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
|
||||
|
||||
# map ASSIGN to BUFFER after ScheduleItems are constructed
|
||||
for k,v in becomes_map.items():
|
||||
if v.base.op is Ops.ASSIGN:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue