mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
don't render index (#12796)
* don't render index * update to ignore_indexing --------- Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
parent
cad3ada909
commit
ba593f7b98
2 changed files with 11 additions and 7 deletions
|
|
@ -149,7 +149,7 @@ class TestViz(BaseTestViz):
|
|||
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
|
||||
|
||||
def test_inf_loop(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
|
||||
b = a.replace(op=Ops.CONST)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.CONST)),
|
||||
|
|
@ -164,8 +164,8 @@ class TestViz(BaseTestViz):
|
|||
self.assertEqual(graphs[2], uop_to_json(nop)[id(nop)])
|
||||
|
||||
def test_const_node_visibility(self):
|
||||
a = UOp.variable("a", 0, 10)
|
||||
z = UOp.const(dtypes.index, 0)
|
||||
a = UOp.variable("a", 0, 10, dtype=dtypes.int)
|
||||
z = UOp.const(a.dtype, 0)
|
||||
alu = a*z
|
||||
exec_rewrite(alu, [sym])
|
||||
lst = get_viz_list()
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ def pystr(u:UOp, i:int) -> str:
|
|||
except Exception: pass
|
||||
return str(u)
|
||||
|
||||
def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
def uop_to_json(x:UOp, ignore_indexing=False) -> dict[int, dict]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: dict[int, dict] = {}
|
||||
excluded: set[UOp] = set()
|
||||
|
|
@ -64,13 +64,14 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
|
||||
if u.dtype.scalar() is dtypes.index and ignore_indexing: excluded.update(u.backward_slice_with_self)
|
||||
for u in toposort:
|
||||
if u in excluded: continue
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
|
||||
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
|
||||
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
||||
for idx,x in enumerate(u.src):
|
||||
for idx,x in enumerate(u.src[:1] if u.op in {Ops.BUFFERIZE, Ops.INDEX} else u.src):
|
||||
if x in excluded:
|
||||
arg = f"{x.arg:g}" if x.op is Ops.CONST and dtypes.is_float(x.dtype) else f"{x.arg}"
|
||||
label += f"\n{x.op.name}{idx} {arg}" + (f" {x.src[0].op}" if len(x.src) else "")
|
||||
|
|
@ -97,14 +98,17 @@ def _reconstruct(a:int):
|
|||
return UOp(op, dtype, tuple(_reconstruct(s) for s in src), arg, *rest)
|
||||
|
||||
def get_full_rewrite(ctx:TrackedGraphRewrite, i:int=0) -> Generator[GraphRewriteDetails, None, None]:
|
||||
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink)), "uop":pystr(next_sink,i), "changed_nodes":None, "diff":None, "upat":None}
|
||||
ignore_indexing = not (isinstance(trace.keys[i].ret, ProgramSpec) or ctx.name in {"kernel split"})
|
||||
yield {"graph":uop_to_json(next_sink:=_reconstruct(ctx.sink), ignore_indexing), "uop":pystr(next_sink,i), "changed_nodes":None,
|
||||
"diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink,i), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink, ignore_indexing)), "uop":pystr(new_sink,i),
|
||||
"changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pystr(u0,i).splitlines(),pystr(u1,i).splitlines())), "upat":(upat_loc, match_repr)}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue