mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
faster process replay [pr] (#13564)
This commit is contained in:
parent
6eab756578
commit
3eae146139
3 changed files with 10 additions and 10 deletions
|
|
@ -177,7 +177,7 @@ WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1),
|
|||
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
|
||||
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
|
||||
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
|
||||
PICKLE_BUFFERS, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("LRU", 1)
|
||||
LRU = ContextVar("LRU", 1)
|
||||
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
|
||||
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
|
||||
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
|
||||
|
|
|
|||
|
|
@ -538,7 +538,7 @@ replace_contiguous = PatternMatcher([
|
|||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
uop_list: list[UOp] = []
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from enum import Enum, auto
|
|||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
|
||||
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CI
|
||||
from tinygrad.helpers import strip_parens, colored, ansilen, printable
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
|
|
@ -133,7 +133,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
except AttributeError: pass
|
||||
def __reduce__(self):
|
||||
args = [self.op, self.dtype, self.src, self.arg, self.tag, self.metadata]
|
||||
if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
|
||||
if self.op is Ops.BUFFER and self.realized is not None: args.append(self.realized)
|
||||
return UOp, tuple(args)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src),
|
||||
|
|
@ -1072,11 +1072,12 @@ tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
|||
_name_cnt:dict[str, itertools.count] = {}
|
||||
|
||||
if getenv("CAPTURE_PROCESS_REPLAY"):
|
||||
replay_capture: dict[str, bytes] = {}
|
||||
import atexit
|
||||
replay_capture: list[bytes] = []
|
||||
import atexit, uuid
|
||||
@atexit.register
|
||||
def save_to_diskcache():
|
||||
for k,v in replay_capture.items(): diskcache_put("process_replay", k, v, prepickled=True)
|
||||
uid = uuid.uuid4() # one id per process
|
||||
for i,v in enumerate(replay_capture): diskcache_put("process_replay", f"{uid}_{i}", v, prepickled=True)
|
||||
|
||||
def add_trace_group(kt:TracingKey) -> None:
|
||||
tracked_keys.append(kt)
|
||||
|
|
@ -1100,9 +1101,8 @@ def track_rewrites(name:Callable[..., str|TracingKey]|bool=True, replay:bool=Fal
|
|||
while (f_back:=frm.f_back) is not None and "unittest" not in f_back.f_code.co_filename: frm = f_back
|
||||
loc = f"{frm.f_code.co_filename.split('/')[-1]}:{frm.f_lineno} {frm.f_code.co_name}"
|
||||
# capture global context vars and all the args passed in
|
||||
with Context(PICKLE_BUFFERS=0):
|
||||
inputs = (fn, args, kwargs, ContextVar._cache)
|
||||
replay_capture[hashlib.sha256(pickle.dumps(inputs)).hexdigest()] = pickle.dumps(inputs+(loc, ret))
|
||||
inputs = (fn, args, kwargs, ContextVar._cache)
|
||||
replay_capture.append(pickle.dumps(inputs+(loc, ret)))
|
||||
return ret
|
||||
return __wrapper
|
||||
return _decorator
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue