Compare commits

...

6 commits

Author SHA1 Message Date
George Hotz
771a395240
Merge branch 'master' into kernel_is_call 2026-02-06 09:15:03 +08:00
George Hotz
d83ddc05c8 resolve_call 2026-02-05 12:57:31 +08:00
George Hotz
8e8cac4b0f don't use tag, use KernelInfo 2026-02-05 12:31:14 +08:00
George Hotz
57199fd9de keep the all buffers on same device check 2026-02-05 12:17:32 +08:00
George Hotz
2193d0edfa fix arg order 2026-02-05 12:03:45 +08:00
George Hotz
77adccb925 use call for kernel 2026-02-05 11:48:50 +08:00
10 changed files with 92 additions and 47 deletions

View file

@ -1,7 +1,7 @@
import time
from typing import cast
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, CallInfo
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
@ -28,8 +28,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
in_degree.setdefault(k, 0)
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}"
for s in k.src[0].src if k.op is Ops.END else k.src:
for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
match (s := _unwrap_src(s)).op:
case Ops.AFTER:
children.setdefault(s.src[1], []).append(k)
@ -54,14 +53,13 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
while len(queue):
k = rk = queue.popleft()
if k.op is Ops.END: k = k.src[0]
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
assert k.op in {Ops.RANGE, Ops.CALL}, f"unexpected op in queue: {k.op}"
if k.op is Ops.RANGE: schedule.append(k)
elif k.op is Ops.KERNEL:
ast = (kernel:=cast(Kernel, k.arg)).ast
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
sched_item[k] = (ast, buf_uops, kernel.metadata, bound_ranges)
schedule.append(k)
elif k.op is Ops.CALL:
ast = k.src[0]
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
schedule.append((ast, buf_uops, cast(CallInfo, k.arg).metadata, {}, bound_ranges))
if rk.op is Ops.END: schedule.append(rk)
for x in children.get(rk, []):
in_degree[x] -= 1

View file

@ -14,7 +14,7 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)
def call_gradient(ctx:UOp, k:UOp):
if k.arg is not None: return (None,) + k.arg(ctx, k)
if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
# auto-differentiate the function
fxn, args = k.src[0], k.src[1:]
params = sorted([x for x in fxn.toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)

View file

@ -9,7 +9,7 @@ from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL, Ops.ENCDEC}
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.CALL, Ops.ENCDEC}
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
@ -20,7 +20,7 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
# if it's a kernel, we don't realize it
if a.src[1].op is not Ops.KERNEL: ctx[a] = None
if a.src[1].op is not Ops.CALL: ctx[a] = None
pm_generate_realize_map = PatternMatcher([
# always realize SINK src
@ -99,7 +99,7 @@ def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
if assign.src[1].op is Ops.KERNEL: return None
if assign.src[1].op is Ops.CALL: return None
to_mop = graph_rewrite(assign.src[0], PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.replace(tag=()))]))
ret = assign.replace(src=assign.src+(to_mop,))
ctx.range_map[ret] = ctx.range_map[assign]
@ -174,7 +174,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if x.op in {Ops.DEVICE, Ops.UNIQUE}: continue
# no ranges on kernels, they are internal
if x.op is Ops.KERNEL: continue
if x.op is Ops.CALL: continue
if x.dtype.scalar() == dtypes.index: continue # TODO: why do I need this?
ending_ranges[x] = sum([ending_ranges.get(u, []) for u in consumer_map[x]], [])

View file

@ -2,10 +2,10 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@ -66,9 +66,11 @@ mop_cleanup = PatternMatcher([
def resolve_custom_kernel(ck:UOp) -> UOp:
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
return UOp(Ops.KERNEL, src=ck.src, arg=Kernel(ck.arg.fxn(*placeholders)))
return ck.arg.fxn(*placeholders).call(*ck.src)
def resolve_call(c:UOp) -> UOp:
def resolve_call(c:UOp) -> UOp|None:
# don't resolve CALLs with SINK - those are kernel calls from split_store/custom_kernel
if c.src[0].op is Ops.SINK: return None
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
@ -83,7 +85,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# resolve calls
# resolve calls (but not CALLs with SINK - those are kernel calls from split_store)
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve custom kernels
@ -384,8 +386,8 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
# remove any RESHAPEs on KERNEL
(UPat(Ops.KERNEL, name="k"), lambda k: k.replace(src=tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src))),
# remove any RESHAPEs on CALL
(UPat(Ops.CALL, name="k"), lambda k: k.replace(src=(k.src[0],)+tuple(x.src[0] if x.op is Ops.RESHAPE else x for x in k.src[1:]))),
])
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
@ -515,16 +517,17 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
else: raise RuntimeError(f"unknown kernel type {ret.op}")
if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
else:
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else None)
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts) if lctx.opts is not None else KernelInfo())
metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
return kernel
split_kernels = PatternMatcher([
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
# if it's a Kernel, stop
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
])
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
@ -537,7 +540,7 @@ def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
return x.replace(tag=(len(ctx[0])-1,))
add_tags = PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.CALL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
])
@ -581,7 +584,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
# bufferize -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels", bottom_up=True)
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}

View file

@ -240,7 +240,7 @@ class Tensor(OpMixin):
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param, device=self.device)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
return Tensor(UOp.call(*[t.uop for t in (self,)+lst], fxn=fxn.uop if isinstance(fxn, Tensor) else fxn, arg=grad_fxn), device=self.device)
return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
"""

View file

@ -364,7 +364,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property
def trace_num(self):
num = next(ucount)
# KERNEL also has a UOp in the arg
# KERNEL has a UOp in the arg, CALL has it in src[0] so no special handling needed
arg = type(self.arg)(self.arg.ast.trace_num, self.arg.metadata) if self.op is Ops.KERNEL else self.arg
uop_fields[num] = (self.op, self.dtype, tuple(s.trace_num for s in self.src), arg, self.tag)+((self.metadata,) if TRACEMETA>=2 else ())
return num
@ -818,7 +818,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
src = (UOp(Ops.NOOP) if shape is None else shape_to_shape_arg(shape),) + (() if device is None else (UOp(Ops.DEVICE, arg=device),))
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(*srcs:UOp, fxn:UOp, arg:Any|None) -> UOp: return UOp(Ops.CALL, fxn.dtype, (fxn,)+srcs, arg)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
kernel = UOp(Ops.CUSTOM_KERNEL, src=contig_srcs, arg=CustomKernel(fxn=fxn, grad_fxn=grad_fxn))
@ -843,6 +844,14 @@ class CustomKernel:
def __reduce__(self): return (CustomKernel, (panic,))
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
@dataclass(frozen=True)
class CallInfo:
grad_fxn: Callable|None = None
metadata: tuple[Metadata, ...] = ()
# grad_fxn can't be pickled, but metadata can
def __reduce__(self): return (CallInfo, (None, self.metadata))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
@dataclass(frozen=True)
class Kernel:
ast: UOp
@ -1405,7 +1414,7 @@ pm_pyrender_extra = PatternMatcher([
# NOTE: you can remove pm_pyrender_extra and it'll still be correct
pm_pyrender = pm_pyrender_extra+PatternMatcher([
(UPat(Ops.KERNEL, name="u"), lambda ctx,u: f"UOp(Ops.KERNEL, src={srcs(ctx,u.src)}, arg=Kernel({ctx[u.arg.ast]}(), {u.arg.metadata}))"),
(UPat(Ops.CALL, name="u"), lambda ctx,u: f"{ctx[u.src[0]]}.call({', '.join(ctx[s] for s in u.src[1:])}, metadata={u.arg.metadata})"),
(UPat(GroupOp.All, name="u"), lambda ctx,u: f"UOp({u.op}, {u.dtype}, {srcs(ctx,u.src)}"+(f", {repr(u.arg)})" if u.arg is not None else ")")),
])
@ -1415,7 +1424,7 @@ def pyrender(ast:UOp) -> str:
cmap = consumer_map_from_toposort(lst)
not_rendered = {Ops.CONST, Ops.VCONST, Ops.DEVICE}
always_rendered = {Ops.PARAM, Ops.LOAD, Ops.SPECIAL, Ops.RANGE, Ops.CONTIGUOUS, Ops.VECTORIZE,
Ops.BUFFER, Ops.COPY, Ops.KERNEL, Ops.WHERE, Ops.END, Ops.ASSIGN}
Ops.BUFFER, Ops.COPY, Ops.CALL, Ops.WHERE, Ops.END, Ops.ASSIGN}
to_render: set[UOp] = {ast}
for u in lst:

View file

@ -87,8 +87,9 @@ _tensor_spec = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
# CALL can attach to an AFTER to describe the compute required to realize a BUFFER
# src[0] is the function (SINK), src[1:] are buffers/bindings
(UPat(Ops.CALL, src=(UPat(Ops.SINK), UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), allow_any_len=True), lambda: True),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
@ -249,8 +250,8 @@ full_spec = PatternMatcher([
# vectorized index
(UPat(Ops.INDEX, src=(UPat((Ops.VECTORIZE, Ops.CAST)), UPat())), lambda: True),
# linearizer: outputs + intermediate KERNELs
(UPat(Ops.KERNEL, dtype=dtypes.void), lambda: True),
# linearizer: outputs + intermediate CALLs
(UPat(Ops.CALL, dtype=dtypes.void), lambda: True),
# Invalid must have type Index
(UPat(Ops.CONST, arg=Invalid, name="x"), lambda x: x.dtype.scalar() == dtypes.index),

View file

@ -251,7 +251,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x", dtypes.index) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.KERNEL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
tuple(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.BARRIER, Ops.END, Ops.UNROLL} else y.src for y in x.src[1:]])))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),
# VECTORIZE/CONST

View file

@ -738,12 +738,13 @@ window.addEventListener("popstate", (e) => {
});
const createToggle = (id, text) => {
const label = d3.create("label").text(text).node();
const label = d3.create("label").style("display", "block").text(text).node();
const toggle = d3.create("input").attr("type", "checkbox").attr("id", id).property("checked", true).node();
label.prepend(toggle);
return { toggle, label };
}
const { toggle, label:toggleLabel } = createToggle("show-indexing", "Show indexing (r)");
const showIndexing = createToggle("show-indexing", "Show indexing (r)");
const showCallSrc = createToggle("show-call-src", "Show CALL src (c)");
const showGraph = createToggle("show-graph", "Show graph (g)");
showGraph.toggle.onchange = () => displaySelection(rect("#graph").width > 0 ? "#custom" : "#graph");
@ -893,11 +894,13 @@ async function main() {
// ** center graph
const data = ret[currentRewrite];
const render = (opts) => renderDag({ data, opts }, { recenter:currentRewrite === 0 });
render({ showIndexing:toggle.checked });
toggle.onchange = (e) => render({ showIndexing:e.target.checked });
const getOpts = () => ({ showIndexing:showIndexing.toggle.checked, showCallSrc:showCallSrc.toggle.checked });
render(getOpts());
showIndexing.toggle.onchange = () => render(getOpts());
showCallSrc.toggle.onchange = () => render(getOpts());
// ** right sidebar metadata
metadata.innerHTML = "";
if (ckey.includes("rewrites")) metadata.appendChild(toggleLabel);
if (ckey.includes("rewrites")) metadata.append(showIndexing.label, showCallSrc.label);
if (step.code_line != null) metadata.appendChild(codeBlock(step.code_line, "python", { loc:step.loc, wrap:true }));
if (step.trace) {
const trace = d3.create("pre").append("code").classed("hljs", true);
@ -1025,7 +1028,9 @@ document.addEventListener("keydown", (event) => {
document.getElementById("zoom-to-fit-btn").click();
}
// r key toggles indexing
if (event.key === "r") toggle.click();
if (event.key === "r") showIndexing.toggle.click();
// c key toggles CALL src
if (event.key === "c") showCallSrc.toggle.click();
// g key toggles graph
if (event.key === "g") showGraph.toggle.click();
});

View file

@ -55,13 +55,42 @@ const layoutUOp = (g, { graph, change }, opts) => {
for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}});
if (change?.includes(parseInt(k))) g.setParent(k, "overlay");
}
// optionally hide nodes from the layuot
// optionally hide nodes from the layout
if (!opts.showIndexing) {
for (const n of g.nodes()) {
const node = g.node(n);
if (node.label.includes("dtypes.index")) g.removeNode(n);
}
}
if (!opts.showCallSrc) {
// remove edges from src[0] to CALL nodes, track affected nodes
const disconnected = new Set();
for (const n of g.nodes()) {
const node = g.node(n);
if (node?.label?.startsWith("CALL\n") || node?.label === "CALL") {
for (const pred of (g.predecessors(n) || [])) {
const edge = g.edge(pred, n);
if (edge?.label?.text === 0) {
g.removeEdge(pred, n);
disconnected.add(pred);
}
}
}
}
// remove nodes that are now disconnected (no successors), only from affected subtree
let changed = true;
while (changed) {
changed = false;
for (const n of disconnected) {
if (!g.hasNode(n)) continue;
if ((g.successors(n) || []).length === 0) {
for (const pred of (g.predecessors(n) || [])) disconnected.add(pred);
g.removeNode(n);
changed = true;
}
}
}
}
dagre.layout(g);
// remove overlay node if it's empty
if (!g.node("overlay")?.width) g.removeNode("overlay");