mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hcq2: debug=2 info (#16569)
* hcq2: debug=2 info * t * x * hcq2: debug=2 info * x
This commit is contained in:
parent
a83710396c
commit
e5f498de3b
5 changed files with 55 additions and 33 deletions
|
|
@ -1,9 +1,8 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast, Callable, TypeVar, Generic, Any, TYPE_CHECKING
|
||||
from typing import cast, Callable, TypeVar, Generic, Any
|
||||
import struct, functools, time, collections, importlib, itertools, weakref
|
||||
from dataclasses import replace
|
||||
if TYPE_CHECKING: from tinygrad.engine.realize import ExecContext
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize
|
||||
from tinygrad.helpers import DEV, getenv, select_first_inited, select_by_name, suppress_finalizing, mv_address, DEBUG, dedup, pluralize, to_tuple
|
||||
from tinygrad.device import Device, Buffer, BufferSpec, Compiled, LRUAllocator, MultiBuffer
|
||||
from tinygrad.uop.ops import Ops, sint, UOp, UPat, PatternMatcher, KernelInfo, graph_rewrite, track_rewrites, GroupOp
|
||||
from tinygrad.uop.symbolic import symbolic_simple, symbolic
|
||||
|
|
@ -12,7 +11,7 @@ from dataclasses import dataclass, field
|
|||
from tinygrad.runtime.support.memory import BumpAllocator
|
||||
from tinygrad.runtime.support.hcq import MMIOInterface
|
||||
from tinygrad.renderer import Renderer, Estimates
|
||||
from tinygrad.engine.realize import to_program, track_stats, get_call_arg_uops, resolve_params, pm_flatten_linear
|
||||
from tinygrad.engine.realize import to_program, get_call_arg_uops, get_call_name, get_call_outs_ins, estimate_uop, pm_flatten_linear
|
||||
from tinygrad.engine.jit import DepsTracker
|
||||
|
||||
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQ2Compiled')
|
||||
|
|
@ -166,10 +165,17 @@ def make_signal_value(devs, queue=None): return UOp.new_buffer(devs, 1, dtypes.u
|
|||
HCQ_DEVS = frozenset(("AMD",))
|
||||
HCQ_P2P_DEVS = HCQ_DEVS | frozenset(("CPU",))
|
||||
|
||||
def to_tuple(d): return d if isinstance(d, tuple) else (d,)
|
||||
|
||||
def all_devices_in(d:Any, c:frozenset[str]) -> bool: return {x.split(":")[0] for x in to_tuple(d)} <= c
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HCQInfo:
|
||||
name:str = ""
|
||||
estimates:Estimates = Estimates()
|
||||
outs:tuple[int, ...] = ()
|
||||
|
||||
@staticmethod
|
||||
def from_call(call:UOp) -> HCQInfo: return HCQInfo(get_call_name(call, get_call_arg_uops(call)), estimate_uop(call), get_call_outs_ins(call)[0])
|
||||
|
||||
# *****************
|
||||
# 1.1. prep runtimes: staging copies
|
||||
|
||||
|
|
@ -199,7 +205,7 @@ def prep_program(call:UOp, prg:UOp) -> UOp|None:
|
|||
data, image_bytes = lowered
|
||||
buf = UOp.new_buffer(dev, len(image_bytes), dtypes.uint8).rtag("program")
|
||||
blob = UOp(Ops.BINARY, dtypes.void, src=(), arg=image_bytes)
|
||||
return call.replace(src=(prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)),) + call.src[1:])
|
||||
return prg.replace(src=(buf.after(buf.store(blob)),), arg=(data, prg.arg)).call(*call.src[1:], aux=HCQInfo.from_call(call))
|
||||
|
||||
def prep_kernargs(call:UOp, prg:UOp) -> UOp:
|
||||
data, info = prg.arg
|
||||
|
|
@ -229,14 +235,14 @@ def make_submit(*cmds, devs:str|tuple[str, ...], queue:str) -> UOp:
|
|||
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, dtypes.void, src=tuple(cmds), arg=(devs, queue)),), arg="submit")
|
||||
|
||||
def lower_program(call:UOp, prg:UOp) -> UOp:
|
||||
return make_submit(prg, devs=call.src[1].device, queue="COMPUTE:0").sink().call(*call.src[1:]).rtag(("hcq", tuple(prg.arg[1].outs)))
|
||||
return make_submit(prg, devs=call.src[1].device, queue="COMPUTE:0").sink().call(*call.src[1:], aux=call.arg.aux).rtag("hcq")
|
||||
|
||||
def lower_copy(call:UOp, copy:UOp) -> UOp|None:
|
||||
dst, src = call.src[1], call.src[2]
|
||||
if (hcq_dev:=next((b.device for b in (dst, src) if b.device.split(":")[0] in HCQ_DEVS), None)) is None: return None
|
||||
|
||||
cp_op = UOp(Ops.COPY, dtypes.void, src=(dst, src), arg=src.buffer.nbytes)
|
||||
return make_submit(cp_op, devs=hcq_dev, queue="COPY:0").sink().call(*call.src[1:]).rtag(("hcq", (0,)))
|
||||
return make_submit(cp_op, devs=hcq_dev, queue="COPY:0").sink().call(*call.src[1:], aux=HCQInfo.from_call(call)).rtag("hcq")
|
||||
|
||||
pm_lower_ops = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, src=(UPat(Ops.BUFFER).or_after(), UPat(Ops.BUFFER).or_after()), name="prg"),),
|
||||
|
|
@ -267,7 +273,7 @@ class DepsCtx:
|
|||
def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
new_src = []
|
||||
for call in linear.src:
|
||||
if not isinstance(call.tag, tuple) or not call.tag[0] == "hcq":
|
||||
if call.tag != "hcq":
|
||||
new_src.append(call)
|
||||
continue
|
||||
|
||||
|
|
@ -277,10 +283,10 @@ def schedule_inner_sync(ctx:DepsCtx, linear:UOp) -> UOp:
|
|||
deps = []
|
||||
refs = [b.buffer for b in get_call_arg_uops(call)]
|
||||
for lane in range(len(refs[0].bufs) if isinstance(refs[0], MultiBuffer) else 1):
|
||||
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.tag[1], new_q)
|
||||
deps += ctx.deps.access_resources([b.bufs[lane] if isinstance(b, MultiBuffer) else b for b in refs], call.arg.aux.outs, new_q)
|
||||
|
||||
new_q = new_q.after(*dps).rtag("deps") if (dps:=dedup(deps)) else new_q
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:]), tag="hcq"))
|
||||
new_src.append(call.replace(src=(call.src[0].substitute({q:new_q}), *call.src[1:])))
|
||||
return linear.replace(src=tuple(new_src))
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), schedule_inner_sync)])
|
||||
|
||||
|
|
@ -295,7 +301,8 @@ def make_finalizer(queues:list[UOp], nbump:int) -> UOp:
|
|||
submit = make_submit(make_signal(devs).store(tl.index(zero) + 1), devs=devs, queue="COMPUTE:0")
|
||||
|
||||
upd = [(tl, 1)] + [(make_signal_value(devs, queue=qn), nbump) for qn in dedup([q.arg[1] for q in queues])]
|
||||
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]).sink().call().rtag("hcq")
|
||||
return UOp.barrier(*[s.after(submit).index(zero, dtype=s.dtype.ptr()).store(s.index(zero) + inc) for s, inc in upd]) \
|
||||
.sink().call(aux=HCQInfo("hcq finalizer")).rtag("hcq")
|
||||
|
||||
def add_finalizer(ctx:DepsCtx, linear:UOp) -> UOp:
|
||||
parts:dict[str, list[UOp]] = collections.defaultdict(list)
|
||||
|
|
@ -345,22 +352,24 @@ def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
|
|||
|
||||
def merge_queues(linear:UOp) -> UOp:
|
||||
new_src:list[UOp] = []
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], UOp] = {} # (devs, queue) -> sink, kept in submit order
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], tuple[UOp, HCQInfo]] = {} # (devs, queue) -> (sink, aux), kept in submit order
|
||||
|
||||
for call in linear.src:
|
||||
if call.tag != "hcq":
|
||||
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in list(opened_qs)] + [call]
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in list(opened_qs)] + [call]
|
||||
continue
|
||||
|
||||
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
|
||||
if (old_sink:=opened_qs.pop((devs, queue), None)) is not None:
|
||||
new_sink = merge_sinks(old_sink, new_sink) # exact same queue: merge, and re-insert at the end
|
||||
aux = call.arg.aux
|
||||
if (old:=opened_qs.pop((devs, queue), None)) is not None:
|
||||
new_sink = merge_sinks(old[0], new_sink) # exact same queue: merge, and re-insert at the end
|
||||
aux = replace(aux, name=f"{queue.lower()} submit", estimates=old[1].estimates + aux.estimates)
|
||||
else:
|
||||
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
|
||||
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
|
||||
opened_qs[(devs, queue)] = new_sink
|
||||
new_src += [(sa:=opened_qs.pop(k))[0].call(aux=sa[1]).rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
|
||||
opened_qs[(devs, queue)] = (new_sink, aux)
|
||||
|
||||
return linear.replace(src=tuple(new_src + [sink.call().rtag('hcq') for sink in opened_qs.values()]))
|
||||
return linear.replace(src=tuple(new_src + [sink.call(aux=aux).rtag('hcq') for sink, aux in opened_qs.values()]))
|
||||
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
|
||||
|
||||
# *****************
|
||||
|
|
@ -467,8 +476,8 @@ def parametrize_host_buffers(call:UOp) -> UOp:
|
|||
pm_parametrize_host_buffers = PatternMatcher([(UPat(Ops.CALL, tag="hcq", name="call"), parametrize_host_buffers)])
|
||||
|
||||
def callify_hcq(call:UOp) -> UOp:
|
||||
sink = UOp.sink(call.src[0], arg=KernelInfo(name="hcq_submit", estimates=Estimates()), tag=1)
|
||||
return to_program(sink, Device["CPU"].renderer).call(*call.src[1:])
|
||||
prg = to_program(call.src[0].sink(arg=KernelInfo("hcq_submit"), tag=1), Device["CPU"].renderer)
|
||||
return UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(prg,), arg="hcq").call(*call.src[1:], aux=call.arg.aux)
|
||||
pm_callify_hcq = PatternMatcher([(UPat(Ops.CALL, tag="hcq_param", name="call"), callify_hcq)])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"HCQ Schedule {pluralize('Kernel', len(ret.src))}")
|
||||
|
|
@ -477,7 +486,7 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
linear = graph_rewrite(linear, pm_prep_runtime, name="prepare runtime")
|
||||
|
||||
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
|
||||
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), name="schedule inner sync")
|
||||
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), walk=True, name="schedule inner sync")
|
||||
linear = graph_rewrite(linear, pm_add_finalizer, ctx=deps_ctx, walk=True, name="add finalizer")
|
||||
linear = graph_rewrite(linear, pm_add_inner_loads, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ from typing import cast
|
|||
import os, ctypes, struct, hashlib, functools, importlib, mmap, errno, array, contextlib, sys, weakref, itertools, collections, atexit
|
||||
assert sys.platform != 'win32'
|
||||
from dataclasses import dataclass
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, to_tuple, make_getaddr, make_ins, make_cmdbuf
|
||||
from extra.hcq2.hcq2 import HCQ2Compiled, HCQAllocator, HCQ2Buffer, make_getaddr, make_ins, make_cmdbuf
|
||||
from tinygrad.uop.ops import sint, UOp
|
||||
from tinygrad.device import Compiled, BufferSpec, Buffer, Device
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import getenv, round_up, data64_le, DEBUG, PROFILE, ProfileEvent, lo32, hi32, colored, prod, ContextVar, TracingKey
|
||||
from tinygrad.helpers import VIZ, ceildiv, unwrap, pluralize
|
||||
from tinygrad.helpers import VIZ, ceildiv, unwrap, pluralize, to_tuple
|
||||
from tinygrad.renderer.cstyle import HIPRenderer, HIPCCRenderer
|
||||
from tinygrad.renderer.llvmir import AMDLLVMRenderer
|
||||
from tinygrad.runtime.autogen import kfd, hsa, sqtt, amdgpu_kd, amdgpu_drm
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import annotations
|
||||
from typing import cast, Iterator, Any
|
||||
from typing import cast, Iterator, Any, Sequence
|
||||
import time, random, itertools, math, contextlib, weakref
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, TRACEMETA, prod, flatten, Context, getenv
|
||||
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, all_int, TRACEMETA, prod, flatten, Context, getenv, dedup, to_tuple
|
||||
from tinygrad.helpers import BEAM, size_to_str, time_to_str, VALIDATE_WITH_CPU, PROFILE, ProfilePointEvent, cpu_events
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite, ProgramInfo
|
||||
|
|
@ -22,17 +22,19 @@ def get_call_outs_ins(call:UOp) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
|||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return (0,), tuple(range(1, len(get_call_arg_uops(call))))
|
||||
return (), ()
|
||||
|
||||
def get_call_name(call:UOp, bufs:list[Buffer], var_vals:dict[str, int]|None=None) -> str:
|
||||
def get_call_name(call:UOp, bufs:Sequence[Buffer|UOp], var_vals:dict[str, int]|None=None) -> str:
|
||||
def _uop_sz_to_str(uop:UOp) -> str: return size_to_str(sym_infer(prod(uop.shape) * uop.dtype.itemsize, var_vals or {}))
|
||||
def _dev_str(buf:Buffer|UOp) -> str: return ', '.join(d[:7] for d in to_tuple(buf.device))
|
||||
|
||||
ast, arg_uops = call.src[0], get_call_arg_uops(call)
|
||||
if ast.op is Ops.PROGRAM: return ast.arg.name
|
||||
if ast.op is Ops.SLICE:
|
||||
offset = ast.src[1].arg * arg_uops[1].dtype.itemsize
|
||||
return colored(f"view {_uop_sz_to_str(arg_uops[0]):>10} @ {offset:<10d}", "yellow")
|
||||
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {bufs[0].device[:7]:>7s} <- {bufs[1].device[:7]:7s}", "yellow")
|
||||
if ast.op is Ops.COPY: return colored(f"copy {_uop_sz_to_str(arg_uops[0]):>10}, {_dev_str(bufs[0]):>7s} <- {_dev_str(bufs[1]):7s}", "yellow")
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return colored(f"enc/dec {_uop_sz_to_str(arg_uops[0])}", "yellow")
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": return colored(f"batched {len(ast.src[0].src)}", "cyan")
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "hcq": return call.arg.aux.name
|
||||
raise NotImplementedError("get_call_name is not implemented")
|
||||
|
||||
# **************** Stat ****************
|
||||
|
|
@ -44,6 +46,7 @@ def estimate_uop(call:UOp) -> Estimates:
|
|||
nbytes = prod(call.src[1].shape) * call.src[1].dtype.itemsize
|
||||
return Estimates(lds=nbytes, mem=nbytes)
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": return get_graph_runtime(ast).estimates
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "hcq": return call.arg.aux.estimates
|
||||
return Estimates()
|
||||
|
||||
first_run_cache:set[bytes] = set()
|
||||
|
|
@ -205,6 +208,13 @@ def exec_graph(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
|||
with track_stats(ctx, call, rt.device, [], ctx.var_vals) as t: t[0] = rt(ctx.input_uops, ctx.var_vals, wait=ctx.wait) # type: ignore[call-arg]
|
||||
return t[0]
|
||||
|
||||
def exec_hcq(ctx:ExecContext, call:UOp, ast:UOp) -> float|None:
|
||||
pm_exec.rewrite(call.replace(src=(ast,) + call.src[1:]), replace(ctx, update_stats=False))
|
||||
for d in dedup(flatten([to_tuple(u.device) for u in resolve_params(call, ctx.input_uops)])):
|
||||
with track_stats(ctx, call, d, [], ctx.var_vals):
|
||||
if ctx.wait: Device[d].synchronize()
|
||||
return None
|
||||
|
||||
# flatten LINEAR-in-LINEAR: any nested LINEAR child gets inlined into its parent's src
|
||||
pm_flatten_linear = PatternMatcher([
|
||||
(UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="lin"),
|
||||
|
|
@ -239,6 +249,7 @@ pm_exec = PatternMatcher([
|
|||
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="ast"),), name="call", allow_any_len=True), exec_kernel),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="ast"),), name="call", allow_any_len=True), exec_encdec),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="ast"),), name="call", allow_any_len=True), exec_graph),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="hcq", src=(UPat(Ops.PROGRAM, name="ast"),)),), name="call", allow_any_len=True), exec_hcq),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="validate", name="ast"),), name="call", allow_any_len=True), exec_validate),
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ def size_to_str(s:int) -> str: return next((f"{s / d:.2f} {pr}" for d,pr in [(1<
|
|||
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
|
||||
def ansilen(s:str): return len(ansistrip(s))
|
||||
def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
|
||||
def to_tuple(x:T|tuple[T, ...]) -> tuple[T, ...]: return x if isinstance(x, tuple) else (x,)
|
||||
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
|
||||
def fully_flatten(l):
|
||||
if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
|
||||
|
|
|
|||
|
|
@ -1080,13 +1080,13 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
# opaque bodies stay as Ops.CALL; value-producing bodies become Ops.FUNCTION (wrapped in TUPLE)
|
||||
_OPAQUE_CALL_BODIES = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.SLICE, Ops.CUSTOM_FUNCTION}
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
|
||||
name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
|
||||
name:str|None=None, precompile:bool=False, precompile_backward:bool=False, aux:Any=None) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
|
||||
if self.op in UOp._OPAQUE_CALL_BODIES:
|
||||
return UOp(Ops.CALL, dtypes.void, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||
return UOp(Ops.CALL, dtypes.void, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward, aux))
|
||||
# value-producing bodies are always wrapped in TUPLE so FUNCTION dtype is always void
|
||||
body = self if self.op is Ops.TUPLE else UOp.maketuple(self)
|
||||
return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
|
||||
return UOp(Ops.FUNCTION, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward, aux))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
|
|
@ -1159,8 +1159,9 @@ class CallInfo:
|
|||
name: str|None = None
|
||||
precompile: bool = False
|
||||
precompile_backward: bool = False
|
||||
aux: Any = None
|
||||
# grad_fxn can't be pickled, but metadata can
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile, self.precompile_backward))
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name, self.precompile, self.precompile_backward, self.aux))
|
||||
def __repr__(self):
|
||||
gf = id(self.grad_fxn) if self.grad_fxn else None
|
||||
return f"CallInfo({gf}, {self.metadata}, {repr(self.name)}, {self.precompile}, {self.precompile_backward})"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue