mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
6 commits
master
...
kernel_is_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
771a395240 |
||
|
|
d83ddc05c8 | ||
|
|
8e8cac4b0f | ||
|
|
57199fd9de | ||
|
|
2193d0edfa | ||
|
|
77adccb925 |
10 changed files with 92 additions and 47 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import time
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, CallInfo
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
|
|
@ -28,8 +28,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
|
||||
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
||||
in_degree.setdefault(k, 0)
|
||||
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
|
||||
match (s := _unwrap_src(s)).op:
|
||||
case Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
|
|
@ -54,14 +53,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
|||
while len(queue):
|
||||
k = rk = queue.popleft()
|
||||
if k.op is Ops.END: k = k.src[0]
|
||||
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
|
||||
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
|
||||
if k.op is Ops.RANGE: schedule.append(k)
|
||||
elif k.op is Ops.KERNEL:
|
||||
ast = (kernel:=cast(Kernel, k.arg)).ast
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
sched_item[k] = (ast, buf_uops, kernel.metadata, bound_ranges)
|
||||
schedule.append(k)
|
||||
elif k.op is Ops.CALL:
|
||||
ast = k.src[0]
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
schedule.append((ast, buf_uops, cast(CallInfo, k.arg).metadata, {}, bound_ranges))
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
for x in children.get(rk, []):
|
||||
in_degree[x] -= 1
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
|||
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
|
||||
|
||||
def call_gradient(ctx:UOp, k:UOp):
|
||||
if k.arg is not None: return (None,) + k.arg(ctx, k)
|
||||
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
|
||||
# auto-differentiate the function
|
||||
fxn, args = k.src[0], k.src[1:]
|
||||
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
|||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC}
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
|
|
@ -20,7 +20,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
|||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
||||
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
|
||||
# if it's a kernel, we don't realize it
|
||||
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
||||
if a.src[1].op is not Ops.CALL: ctx[a] = None
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize SINK src
|
||||
|
|
@ -99,7 +99,7 @@ def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
|||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||
|
||||
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||
if assign.src[1].op is Ops.KERNEL: return None
|
||||
if assign.src[1].op is Ops.CALL: return None
|
||||
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
|
||||
ret = assign.replace(src=assign.src+(to_mop,))
|
||||
ctx.range_map[ret] = ctx.range_map[assign]
|
||||
|
|
@ -174,7 +174,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
|
||||
|
||||
# no ranges on kernels, they are internal
|
||||
if x.op is Ops.KERNEL: continue
|
||||
if x.op is Ops.CALL: continue
|
||||
|
||||
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
|
||||
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ from dataclasses import dataclass, field, replace
|
|||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
|
|
@ -66,9 +66,11 @@ mop_cleanup = PatternMatcher([
|
|||
|
||||
def resolve_custom_kernel(ck:UOp) -> UOp:
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
|
||||
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
|
||||
return ck.arg.fxn(*placeholders).call(*ck.src)
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve CALLs with SINK - those are kernel calls from split_store/custom_kernel
|
||||
if c.src[0].op is Ops.SINK: return None
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
|
|
@ -83,7 +85,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
|||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
|
||||
|
||||
# resolve calls
|
||||
# resolve calls (but not CALLs with SINK - those are kernel calls from split_store)
|
||||
(UPat(Ops.CALL, name="c"), resolve_call),
|
||||
|
||||
# resolve custom kernels
|
||||
|
|
@ -384,8 +386,8 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
|
||||
|
||||
# remove any RESHAPEs on KERNEL
|
||||
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
|
||||
# remove any RESHAPEs on CALL
|
||||
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=(k.src[0],)+tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src[1:]))),
|
||||
])
|
||||
|
||||
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
||||
|
|
@ -515,16 +517,17 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
|||
else: raise RuntimeError(f"unknown kernel type {ret.op}")
|
||||
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
|
||||
else:
|
||||
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)
|
||||
|
||||
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
|
||||
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
|
||||
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else KernelInfo())
|
||||
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
|
||||
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
|
||||
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
|
||||
return kernel
|
||||
|
||||
split_kernels = PatternMatcher([
|
||||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
# if it's a Kernel, stop
|
||||
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
|
||||
])
|
||||
|
||||
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
||||
|
|
@ -537,7 +540,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
|||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
add_tags = PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
|
||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
||||
])
|
||||
|
|
@ -581,7 +584,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
# bufferize -> store
|
||||
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
|
||||
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
|
||||
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
|
||||
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels", bottom_up=True)
|
||||
|
||||
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
|
||||
kernel_assign: dict[UOp, UOp] = {}
|
||||
|
|
|
|||
|
|
@ -240,7 +240,7 @@ class Tensor(OpMixin):
|
|||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param, device=self.device)
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)
|
||||
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
|
||||
|
||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
@recursive_property
|
||||
def trace_num(self):
|
||||
num = next(ucount)
|
||||
# KERNEL also has a UOp in the arg
|
||||
# KERNEL has a UOp in the arg, CALL has it in src[0] so no special handling needed
|
||||
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
|
||||
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
|
||||
return num
|
||||
|
|
@ -818,7 +818,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
|
||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
|
||||
|
|
@ -843,6 +844,14 @@ class CustomKernel:
|
|||
def __reduce__(self): return (CustomKernel, (panic,))
|
||||
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CallInfo:
|
||||
grad_fxn: Callable|None = None
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
# grad_fxn can't be pickled, but metadata can
|
||||
def __reduce__(self): return (CallInfo, (None, self.metadata))
|
||||
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
ast: UOp
|
||||
|
|
@ -1405,7 +1414,7 @@ pm_pyrender_extra = PatternMatcher([
|
|||
|
||||
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
|
||||
pm_pyrender = pm_pyrender_extra+PatternMatcher([
|
||||
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
|
||||
(UPat(Ops.CALL, name="u"), lambda ctx,u: f"{ctx[u.src[0]]}.call({', '.join(ctx[s] for s in u.src[1:])}, metadata={u.arg.metadata})"),
|
||||
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
|
||||
])
|
||||
|
||||
|
|
@ -1415,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
|
|||
cmap = consumer_map_from_toposort(lst)
|
||||
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
|
||||
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
|
||||
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
|
||||
|
||||
to_render: set[UOp] = {ast}
|
||||
for u in lst:
|
||||
|
|
|
|||
|
|
@ -87,8 +87,9 @@ _tensor_spec = PatternMatcher([
|
|||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
|
||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
# CALL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
# src[0] is the function (SINK), src[1:] are buffers/bindings
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.SINK), UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), allow_any_len=True), lambda: True),
|
||||
|
||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
|
@ -249,8 +250,8 @@ full_spec = PatternMatcher([
|
|||
# vectorized index
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
|
||||
|
||||
# linearizer: outputs + intermediate KERNELs
|
||||
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
|
||||
# linearizer: outputs + intermediate CALLs
|
||||
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
|
||||
|
||||
# Invalid must have type Index
|
||||
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
|||
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
|
||||
# only RANGE/IF/STORE/KERNEL have side effects
|
||||
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
|
||||
# after with 1 src is just src[0]
|
||||
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
|
||||
# VECTORIZE/CONST
|
||||
|
|
|
|||
|
|
@ -738,12 +738,13 @@ window.addEventListener("popstate", (e) => {
|
|||
});
|
||||
|
||||
const createToggle = (id, text) => {
|
||||
const label = d3.create("label").text(text).node();
|
||||
const label = d3.create("label").style("display", "block").text(text).node();
|
||||
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
|
||||
label.prepend(toggle);
|
||||
return { toggle, label };
|
||||
}
|
||||
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)");
|
||||
const showIndexing = createToggle("show-indexing", "Show indexing (r)");
|
||||
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
|
||||
const showGraph = createToggle("show-graph", "Show graph (g)");
|
||||
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
|
||||
|
||||
|
|
@ -893,11 +894,13 @@ async function main() {
|
|||
// ** center graph
|
||||
const data = ret[currentRewrite];
|
||||
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
|
||||
render({ showIndexing:toggle.checked });
|
||||
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
|
||||
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
|
||||
render(getOpts());
|
||||
showIndexing.toggle.onchange = () => render(getOpts());
|
||||
showCallSrc.toggle.onchange = () => render(getOpts());
|
||||
// ** right sidebar metadata
|
||||
metadata.innerHTML = "";
|
||||
if (ckey.includes("rewrites")) metadata.appendChild(toggleLabel);
|
||||
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
|
||||
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
|
||||
if (step.trace) {
|
||||
const trace = d3.create("pre").append("code").classed("hljs", true);
|
||||
|
|
@ -1025,7 +1028,9 @@ document.addEventListener("keydown", (event) => {
|
|||
document.getElementById("zoom-to-fit-btn").click();
|
||||
}
|
||||
// r key toggles indexing
|
||||
if (event.key === "r") toggle.click();
|
||||
if (event.key === "r") showIndexing.toggle.click();
|
||||
// c key toggles CALL src
|
||||
if (event.key === "c") showCallSrc.toggle.click();
|
||||
// g key toggles graph
|
||||
if (event.key === "g") showGraph.toggle.click();
|
||||
});
|
||||
|
|
|
|||
|
|
@ -55,13 +55,42 @@ const layoutUOp = (g, { graph, change }, opts) => {
|
|||
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
|
||||
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
|
||||
}
|
||||
// optionally hide nodes from the layuot
|
||||
// optionally hide nodes from the layout
|
||||
if (!opts.showIndexing) {
|
||||
for (const n of g.nodes()) {
|
||||
const node = g.node(n);
|
||||
if (node.label.includes("dtypes.index")) g.removeNode(n);
|
||||
}
|
||||
}
|
||||
if (!opts.showCallSrc) {
|
||||
// remove edges from src[0] to CALL nodes, track affected nodes
|
||||
const disconnected = new Set();
|
||||
for (const n of g.nodes()) {
|
||||
const node = g.node(n);
|
||||
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
|
||||
for (const pred of (g.predecessors(n) || [])) {
|
||||
const edge = g.edge(pred, n);
|
||||
if (edge?.label?.text === 0) {
|
||||
g.removeEdge(pred, n);
|
||||
disconnected.add(pred);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// remove nodes that are now disconnected (no successors), only from affected subtree
|
||||
let changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (const n of disconnected) {
|
||||
if (!g.hasNode(n)) continue;
|
||||
if ((g.successors(n) || []).length === 0) {
|
||||
for (const pred of (g.predecessors(n) || [])) disconnected.add(pred);
|
||||
g.removeNode(n);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dagre.layout(g);
|
||||
// remove overlay node if it's empty
|
||||
if (!g.node("overlay")?.width) g.removeNode("overlay");
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue