rename BUFFERIZE to STAGE (#16125)

This commit is contained in:
George Hotz 2026-05-10 09:26:46 -07:00 committed by GitHub
commit daed602569
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 38 additions and 38 deletions

View file

@ -3,7 +3,7 @@ import functools, pickle
from tinygrad.uop.ops import UOp, Ops
from tinygrad.helpers import tqdm, temp, time_to_str, cpu_profile
BENCHMARK_OPS = {Ops.INDEX, Ops.BUFFERIZE}
BENCHMARK_OPS = {Ops.INDEX, Ops.STAGE}
@functools.cache
def create_uop(a:int) -> UOp:

View file

@ -799,12 +799,12 @@ class TestConstBufferize(unittest.TestCase):
from tinygrad.schedule.rangeify import pm_const_buffer_folding, BufferizeOpts
c = UOp.const(dtypes.float, 42.0)
r1 = UOp.range(3, 0)
bufferize_with_range = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1), arg=BufferizeOpts(device="CPU"))
bufferize_with_range = UOp(Ops.STAGE, dtypes.float, (c, r1), arg=BufferizeOpts(device="CPU"))
self.assertEqual(len(bufferize_with_range.src), 2) # const + 1 range
result = graph_rewrite(bufferize_with_range, pm_const_buffer_folding, name='test')
# STAGE should be removed, result is const broadcast to shape
self.assertNotEqual(result.op, Ops.BUFFERIZE)
self.assertNotEqual(result.op, Ops.STAGE)
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
self.assertIn(42.0, const_vals)
@ -814,12 +814,12 @@ class TestConstBufferize(unittest.TestCase):
c = UOp.const(dtypes.float, 3.14)
r1 = UOp.range(3, 0)
r2 = UOp.range(4, 1)
bufferize_with_ranges = UOp(Ops.BUFFERIZE, dtypes.float, (c, r1, r2), arg=BufferizeOpts(device="CPU"))
bufferize_with_ranges = UOp(Ops.STAGE, dtypes.float, (c, r1, r2), arg=BufferizeOpts(device="CPU"))
self.assertEqual(len(bufferize_with_ranges.src), 3) # const + 2 ranges
result = graph_rewrite(bufferize_with_ranges, pm_const_buffer_folding, name='test')
# STAGE should be removed
self.assertNotEqual(result.op, Ops.BUFFERIZE)
self.assertNotEqual(result.op, Ops.STAGE)
const_vals = [u.arg for u in result.toposort() if u.op is Ops.CONST and u.dtype == dtypes.float]
self.assertIn(3.14, const_vals)

View file

@ -98,13 +98,13 @@ expander = PatternMatcher([
# END on UNROLL ends the UNROLL
(UPat(Ops.END, name="u"), end_unrolls),
# STAGE puts UNROLLs for ranges as contract
(UPat(Ops.BUFFERIZE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"),
(UPat(Ops.STAGE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"),
lambda x: x.replace(src=tuple(UOp(Ops.CONTRACT, dtype=s.dtype.vec(x.src[1].src[0].dtype.count), src=(s,), arg=x.src[1].arg) for s in x.src))),
# double expand
(UPat(Ops.UNROLL, name="outer", src=(UPat(Ops.UNROLL, name="inner"),)),
lambda outer, inner: UOp(Ops.UNROLL, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.BUFFERIZE,
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.STAGE,
Ops.STACK, Ops.REDUCE, Ops.END, Ops.AFTER), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# empty UNROLL is NOOP

View file

@ -67,7 +67,7 @@ class Scheduler:
ret = [r for r in self._output_rngs() if r.arg[-1] == AxisType.LOOP]
# exclude any output ranges from global that don't appear in every STAGE
for x in self.ast.toposort():
if x.op is Ops.BUFFERIZE:
if x.op is Ops.STAGE:
ret = [r for r in ret if r in x.ranges]
return ret
@ -347,6 +347,6 @@ def apply_opts(ast:UOp, ren:Renderer, beam:int=0) -> UOp:
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
if not any(u.op is Ops.BUFFERIZE for u in ast.backward_slice):
if not any(u.op is Ops.STAGE for u in ast.backward_slice):
k = hand_coded_optimizations(k)
return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)

View file

@ -54,7 +54,7 @@ class IndexingContext:
return UOp.range(s, next(self.range_idx), axistype) if resolve(s!=1) else UOp.const(dtypes.weakint, 0)
def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
if x.op in {Ops.STAGE, Ops.INDEX}: return None
new_srcs = []
for s in x.src:
new_src = s
@ -74,7 +74,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
# None in the device assigns it a number later
opts = BufferizeOpts(device=s.device, removable=removable) if len(ctx.range_map[s][1]) == len(realized_ranges) else \
BufferizeOpts(device=s.device, addrspace=AddrSpace.LOCAL, removable=removable)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
new_src = UOp(Ops.STAGE, s.dtype, src=(new_src,)+closed_ranges, arg=opts)
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
new_srcs.append(new_src)
# NOTE: do we need this?

View file

@ -246,7 +246,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
indexes: list[UOp] = []
reduces: list[UOp] = []
def red_gate(x:UOp):
if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
if (x.op is Ops.STAGE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
accessed_buffers.append(x)
return False
if x.op is Ops.STORE:
@ -269,7 +269,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
buffer_in_reduce = False
def buf_gate(x:UOp):
nonlocal buffer_in_reduce
if x.op in {Ops.PARAM, Ops.BUFFERIZE}: buffer_in_reduce = True
if x.op in {Ops.PARAM, Ops.STAGE}: buffer_in_reduce = True
return not buffer_in_reduce
UOp.sink(*[x.src[0] for x in reduces]).toposort(gate=buf_gate)
del buf_gate
@ -278,7 +278,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
out_in_ratio = (prod(buf.shape)+1) / (sum([x.numel() for x in accessed_buffers])+1)
if out_in_ratio < 10: return None
# here we have to check the indexes, we might do a partial contig here
local_indexes = [x for x in indexes if x.src[0].op is Ops.BUFFERIZE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
local_indexes = [x for x in indexes if x.src[0].op is Ops.STAGE and x.src[0].arg.addrspace == AddrSpace.LOCAL]
exclude_ranges = UOp.group(*[UOp.group(*x.src[1:]) for x in local_indexes]).ranges
subs = [(k,v) for k,v in zip(buf.src[1:], idx.src[1:]) if k.op is not Ops.CONST]
# if it's bufferized or a reduce, it's pcontig
@ -302,11 +302,11 @@ def remove_noop_bufferize(idx,b2):
return idx.src[0].shrink(tuple((0, s) for s in b2.shape)) if b2.shape else idx.src[0]
pm_const_buffer_folding = pm_mops+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
(UPat(Ops.STAGE, name="b"), cleanup_dead_axes),
# remove noop buffers. if we look at the next index we can remove even more of these
(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), remove_noop_bufferize),
(UPat(Ops.INDEX, name="idx").f(Ops.STAGE, allow_any_len=True, name="b2"), remove_noop_bufferize),
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
(UPat(Ops.CONST, name='c').f(Ops.STAGE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
# indexing a const is a const
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),),), lambda c: c),
# copy on CONST is CONST
@ -320,7 +320,7 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
pm_remove_bufferize = PatternMatcher([
# remove reindexing with cost function
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
(UPat.var("src").f(Ops.STAGE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
# STORE to self is NOOP
(UPat.var("x").store(UPat.var("x")), lambda x: UOp(Ops.NOOP)),
# END on NOOP is NOOP
@ -345,7 +345,7 @@ def late_buffer_view(t:UOp, b:UOp):
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset)), b.src[1]))
to_bufferview = PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
(UPat(Ops.STAGE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
])
DEVICE_MAX_BUFS = {"METAL": 31, "WEBGPU": 8} # TODO: get from device?
@ -357,7 +357,7 @@ def limit_bufs(ctx:IndexingContext, root:UOp):
bufs: set[UOp] = set()
def gate_input(u:UOp):
# TODO: add cache to fix n^2
if is_load:=(u.op in {Ops.BUFFERIZE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
if is_load:=(u.op in {Ops.STAGE, Ops.AFTER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_VAR}): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
@ -394,7 +394,7 @@ def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
ended_stores = []
for store in stores:
store_target = store.src[0]
if store_target.src[0].op is Ops.BUFFERIZE and store_target.src[0].src[0].op is Ops.INDEX:
if store_target.src[0].op is Ops.STAGE and store_target.src[0].src[0].op is Ops.INDEX:
store_target = store_target.src[0].src[0]
if store.src[1] is store_target: continue # skip self-assign
end_rngs = sorted(dedup(tuple(store_target.ranges) + tuple(rngs)), key=lambda x: x.arg)
@ -423,10 +423,10 @@ def flatten_bufferize(x:UOp):
sym_shape = tuple([r.src[0] if r.op is not Ops.CONST else 1 for r in rngs])
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
return ret
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.STAGE, name="x"), flatten_bufferize)])
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
(UPat(Ops.STAGE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
# move RESHAPEs through MSELECT/MSTACK
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
@ -447,7 +447,7 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
])
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), bufferize_to_store),
(UPat(Ops.STAGE, src=(UPat(), UPat(name="idx")), name="x"), bufferize_to_store),
])
# *****************
@ -506,7 +506,7 @@ to_define_global = PatternMatcher([
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
# remove device from local STAGE
(UPat(Ops.BUFFERIZE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
# remove UNIQUE/DEVICE to dedup CONST
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
@ -581,7 +581,7 @@ def get_kernel_graph(sink:UOp) -> UOp:
# stage -> store
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="stage to store")
tsink = graph_rewrite(tsink, split_kernels, bottom_up=True, name="split kernels")
# WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish

View file

@ -96,7 +96,7 @@ class Ops(FastEnum):
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
# buffer ops
BUFFERIZE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto()
STAGE = auto(); COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto(); CUSTOM_FUNCTION = auto()
# the core 6 movement ops! these only exist in the tensor graph
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto()

View file

@ -27,7 +27,7 @@ axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL:
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
AxisType.GROUP_REDUCE: 2, AxisType.REDUCE: 4, AxisType.UNROLL: 5}
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
range_start = {Ops.STAGE: 1, Ops.REDUCE: 1, Ops.WMMA: 3, Ops.END: 1, Ops.CALL: 1, Ops.FUNCTION: 1,
Ops.COPY: 2, Ops.BUFFER_VIEW: 1, Ops.LINEAR: 0}
# https://en.wikipedia.org/wiki/Identity_element
@ -246,7 +246,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.src[0].op is Ops.INDEX: return ()
return (self.arg[0],)
case Ops.CUSTOM_FUNCTION: return None
case Ops.BUFFERIZE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.STAGE: return tuple([int(r.vmax+1) for r in self.src[1:]])
case Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
case Ops.PARAM:
if isinstance(self.dtype, PtrDType): return (self.ptrdtype.size,)
@ -525,7 +525,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.CONTIGUOUS: return self
if self.has_buffer_identity(): return self
return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def bufferize(self, *args, **kwargs): return UOp(Ops.BUFFERIZE, dtype=self.dtype, src=(self,)+args, **kwargs)
def bufferize(self, *args, **kwargs): return UOp(Ops.STAGE, dtype=self.dtype, src=(self,)+args, **kwargs)
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
return UOp(Ops.ALLREDUCE, self.dtype, (self, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device), op)
@ -674,7 +674,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
@recursive_property
def _device(self) -> str|tuple[str, ...]|None:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.BUFFERIZE: return self.arg.device
if self.op is Ops.STAGE: return self.arg.device
if self.op is Ops.AFTER: return self.src[0]._device
if self.op is Ops.MSELECT:
assert isinstance(self.src[0].device, tuple), f"mselect must be on tuple device, getting {self.src[0].device}"
@ -691,7 +691,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
if self.base.op is Ops.AFTER: return self.base.src[0].buf_uop.base
s = self
while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
while len(s.src) and s.op not in {Ops.BUFFER, Ops.PARAM, Ops.STAGE, Ops.MSTACK}: s = s.src[0]
return s
def contiguous_view_offset(self) -> int|None:

View file

@ -49,7 +49,7 @@ renderer = PatternMatcher([
(UPat(Ops.CDIV, name="x"), lambda ctx,x: f"cdiv({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(Ops.CMOD, name="x"), lambda ctx,x: f"cmod({ctx[x.src[0]]}, {ctx[x.src[1]]})"),
(UPat(set(syms.keys()), name="x"), lambda ctx,x: strip_binary_parens(x, ctx[x.src[0]], ctx[x.src[1]], lambda a,b: f"({a}{syms[x.op]}{b})")),
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
(UPat((Ops.INDEX, Ops.STAGE), name="x"), lambda x, ctx: ''.join([f"[{strip_parens(ctx[y])}]" for y in x.src[1:]])),
(UPat(Ops.STACK, name="x"),
lambda ctx,x: f"{{{','.join([ctx[y] for y in x.src])}}}" if not x.src or not all_same(x.src) else f"{{{ctx[x.src[0]]}, ...}}"),
(UPat(GroupOp.All, name="x"), lambda x: str(x)),

View file

@ -215,7 +215,7 @@ kernel_spec = PatternMatcher([
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
# bufferize can be on anything
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True), lambda: True),
(UPat(Ops.STAGE, src=(UPat(),), allow_any_len=True), lambda: True),
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering
(UPat(Ops.REDUCE, src=(UPat(),), allow_any_len=True, name="x"),

View file

@ -290,7 +290,7 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x", dtypes.weakint) + UPat.cvar("c")).cast(dtypes.sints, name="cast"), lambda x,c,cast:x.cast(cast.dtype)+c.cast(cast.dtype)),
# only RANGE/IF/STORE/KERNEL have side effects
(UPat(Ops.AFTER, name="x"), lambda x: x.replace(src=(x.src[0],)+
tuple(dedup(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.FUNCTION, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.BUFFERIZE}
tuple(dedup(flatten([(y,) if y.op in {Ops.RANGE, Ops.STORE, Ops.CALL, Ops.FUNCTION, Ops.BARRIER, Ops.END, Ops.UNROLL, Ops.LINEAR, Ops.STAGE}
else y.src for y in x.src[1:]]))))),
# after with 1 src is just src[0]
(UPat(Ops.AFTER, src=(UPat.var("s"),)), lambda s: s),

View file

@ -53,7 +53,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.CALL: "#00B7C8", Ops.FUNCTION: "#C07788", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.BINARY: "#404040",
Ops.LINEAR: "#7DF4FF",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.BUFFERIZE: "#FF991C", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
Ops.STAGE: "#AC640D", Ops.REWRITE_ERROR: "#ff2e2e", Ops.AFTER: "#8A7866", Ops.END: "#524C46"}
# VIZ API
@ -125,7 +125,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
wrap_len = 200 if u.op is Ops.SOURCE else 80
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''), wrap=wrap_len)) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
for idx,x in enumerate(u.src[:1] if u.op in {Ops.STAGE, Ops.INDEX} else (u.src if u.op is not Ops.END else [])):
if x in excluded:
# walk through excluded movement ops to find the underlying CONST
cx = x
@ -139,7 +139,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
label += f"\n{shape_to_str(u.shape)}"
if u.op in {Ops.CALL, Ops.FUNCTION}:
label += f"\n{u.src[0].key.hex()[:8]}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
if u.op in {Ops.INDEX, Ops.STAGE}:
if len(u.toposort()) < 30: label += f"\n{u.render()}"
ranges: list[UOp] = []
for us in u.src[1:]: ranges += [s for s in us.toposort() if s.op in {Ops.RANGE, Ops.SPECIAL}]