catch cycles in print_tree (#2891)

* feat: smaller tree on references

* fix: shorter line

* fix: huh

* fix: should be all

* feat: cleaner

* fix: extra imports

* fix: pass by reference
This commit is contained in:
wozeparrot 2023-12-21 21:40:37 -05:00 committed by GitHub
commit 5f3d5cfb02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2,7 +2,7 @@ import os, atexit
from typing import List, Any
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.symbolic import NumNode
@ -86,16 +86,19 @@ def log_lazybuffer(lb, scheduled=False):
if nm(lb) not in G.nodes:
# realized but unseen?
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{bm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
def _tree(lazydata, prefix=""):
if type(lazydata).__name__ == "LazyBuffer":
return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
def _tree(lazydata, cycles, cnt, prefix=""):
cnt[0] += 1
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
if (lid := id(lazydata)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
return [f"━⬆︎ goto {cycles[id(lazydata)][0]}: {lazydata.op.name}"]
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
childs = [_tree(c) for c in lazydata.src[:]]
childs = [_tree(c, cycles, cnt) for c in lazydata.src[:]]
for c in childs[:-1]: lines += [f"{c[0]}"] + [f"{l}" for l in c[1:]]
return lines + [""+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata, {}, [-1]))]))
def graph_uops(uops:List[UOp]):
import networkx as nx