don't render index (#12796)

* don't render index

* update to ignore_indexing

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz 2025-10-20 09:48:36 +08:00 committed by GitHub
commit ba593f7b98
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 7 deletions

View file

@ -149,7 +149,7 @@ class TestViz(BaseTestViz):
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
def test_inf_loop(self):
a = UOp.variable('a', 0, 10)
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
b = a.replace(op=Ops.CONST)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
@ -164,8 +164,8 @@ class TestViz(BaseTestViz):
self.assertEqual(graphs[2], uop_to_json(nop)[id(nop)])
def test_const_node_visibility(self):
a = UOp.variable("a", 0, 10)
z = UOp.const(dtypes.index, 0)
a = UOp.variable("a", 0, 10, dtype=dtypes.int)
z = UOp.const(a.dtype, 0)
alu = a*z
exec_rewrite(alu, [sym])
lst = get_viz_list()

View file

@ -56,7 +56,7 @@ def pystr(u:UOp, i:int) -> str:
except Exception: pass
return str(u)
def uop_to_json(x:UOp) -> dict[int, dict]:
def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
assert isinstance(x, UOp)
graph: dict[int, dict] = {}
excluded: set[UOp] = set()
@ -64,13 +64,14 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
if u.dtype.scalar() is dtypes.index and ignore_indexing: excluded.update(u.backward_slice_with_self)
for u in toposort:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src):
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else u.src):
if x in excluded:
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
@ -97,14 +98,17 @@ def _reconstruct(a:int):
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"})
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink), ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None,
"diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink,i), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
yield {"graph":(sink_json:=uop_to_json(new_sink, ignore_indexing)), "uop":pystr(new_sink,i),
"changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pystr(u0,i).splitlines(),pystr(u1,i).splitlines())), "upat":(upat_loc, match_repr)}
if not ctx.bottom_up: next_sink = new_sink