mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
viz/cli: schedule renderer (#16101)
* simpler steps * work * work * iterate * faster * better * simplify more * sys stdin * less * work * work and mv * better * seen bufs * all call graphs * print query * ux * param to buffer / buffer_view * work * respect NO_COLOR in uop_to_json * less * render uops * rm custom renderer * call can't pyrender. * unrelated diff * assert * 5
This commit is contained in:
parent
53f9587099
commit
2dd84416bf
3 changed files with 34 additions and 11 deletions
|
|
@ -320,7 +320,7 @@ class TestVizGC(unittest.TestCase):
|
|||
|
||||
# VIZ integrates with other parts of tinygrad
|
||||
|
||||
from tinygrad import Tensor, Device, TinyJit, Variable
|
||||
from tinygrad import Tensor, Device, TinyJit, Variable, function
|
||||
|
||||
class TestVizIntegration(unittest.TestCase):
|
||||
# codegen supports rendering of code blocks
|
||||
|
|
@ -340,15 +340,18 @@ class TestVizIntegration(unittest.TestCase):
|
|||
c1 = Tensor.empty(4).add(1)
|
||||
c2 = Tensor.empty(8).add(1)
|
||||
sched = c1.schedule_linear(c2)
|
||||
prgs = [to_program(si.src[0], Device[Device.DEFAULT].renderer).arg.name for si in sched.src]
|
||||
with Context(NO_COLOR=0):
|
||||
prgs = [to_program(si.src[0], Device[Device.DEFAULT].renderer).arg.name for si in sched.src]
|
||||
lst = viz.list_items()
|
||||
sched_idx = next(i for i,l in enumerate(lst) if l["name"].startswith("Schedule"))
|
||||
viz_kernel = next(i for i,s in enumerate(lst[sched_idx]["steps"]) if s["name"] == "View Kernel Graph")
|
||||
graph = next(viz.get_details(sched_idx, viz_kernel))["graph"]
|
||||
with Context(NO_COLOR=1):
|
||||
graph = next(viz.get_details(sched_idx, viz_kernel))["graph"]
|
||||
call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
|
||||
for i,n in enumerate(call_nodes):
|
||||
assert n["ref"] is not None
|
||||
self.assertEqual(lst[n["ref"]]["name"], prgs[i])
|
||||
assert ansistrip(prgs[i]) in n["label"], f"CALL must contain kernel name, got {n['label']}"
|
||||
|
||||
@Context(TRACEMETA=2)
|
||||
def test_metadata_tracing(self):
|
||||
|
|
@ -980,5 +983,23 @@ class TestCLI(unittest.TestCase):
|
|||
self.assertEqual(len([s for s in select if s.get("value")]), 1, "debug output was not deduped")
|
||||
self.assertEqual(len([s for s in select if s.get("device") == "NULL"]), CNT, f"expected 4 runs for {name}")
|
||||
|
||||
def test_call_graph(self):
|
||||
@function(precompile=True)
|
||||
def f(x):
|
||||
r = x.sum(axis=1).reshape(32, 1).expand(32, 32).contiguous()
|
||||
return x + r
|
||||
# turn of scache because this test requires a complete schedule rewrite
|
||||
with save_viz() as viz, Context(SCACHE=0):
|
||||
f(f(Tensor.empty(32, 32, device="NULL"))).realize()
|
||||
with write_files(viz) as files, Context(NO_COLOR=1):
|
||||
prgs = [s["name"] for s in run_cli(*files, "-s", "NULL")]
|
||||
with Context(DEBUG=5):
|
||||
out = run_cli(*files, "-s", "TINY")
|
||||
i = next(i for i,s in enumerate(out) if s.get("value", "").lstrip() == "View Kernel Graph")
|
||||
# next print is the CALL graph, CLI outputs exactly as web in TestVizIntegration.test_link_sched_codegen
|
||||
call_nodes = [n for n in out[i+1].values() if n["label"].startswith("CALL")]
|
||||
for i,n in enumerate(call_nodes):
|
||||
assert prgs[i] in n["label"], f"CALL must contain kernel name, got {n['label']}"
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ os.environ["VIZ"] = "0"
|
|||
if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_DFL)
|
||||
from typing import Iterator
|
||||
from tinygrad.viz import serve as viz
|
||||
from tinygrad.viz.serve import fmt_colored
|
||||
from tinygrad.uop.ops import RewriteTrace
|
||||
from tinygrad.helpers import temp, ansistrip, colored, time_to_str, ansilen, ProfilePointEvent, ProfileRangeEvent, TracingKey, unwrap, NO_COLOR, DEBUG
|
||||
|
||||
|
|
@ -45,8 +46,6 @@ def decode_profile(data:bytes) -> dict:
|
|||
for k,rep,num,mode in [u("<IIIB") for _ in range(u("<I")[0])]]}})
|
||||
return {"dur":total_dur, "peak":global_peak, "layout":layout, "markers":markers}
|
||||
|
||||
def fmt_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s
|
||||
|
||||
def to_str(k:str, v) -> str:
|
||||
if k == "FLOPS" or k.startswith("B/s"): return f"{v*1e-9:.0f} G{k}" if v < 1e13 else f"{v*1e-12:.0f} T{k}"
|
||||
if k == "B": return next((f"{v/s:.0f} {u}" for s,u in ((1e9,"GB"),(1e6,"MB"),(1e3,"KB")) if v>=s), f"{v:.0f} B")
|
||||
|
|
@ -65,11 +64,11 @@ def main(args) -> None:
|
|||
|
||||
def emit(val, to_str=str) -> str: return json.dumps(val if isinstance(val, dict) else {"value":val}) if args.json else to_str(val)
|
||||
|
||||
def print_step(step:dict, reconstruct_matches=False) -> None:
|
||||
def print_step(step:dict, print_graph=False, reconstruct_matches=False) -> None:
|
||||
data = viz.get_render(viz_data, step["query"])
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(emit(m["uop"]))
|
||||
if "uop" in m: print(emit(m["graph"] if print_graph else m["uop"]))
|
||||
if not reconstruct_matches: return None
|
||||
if m.get("diff"):
|
||||
loc = pathlib.Path(m["upat"][0][0])
|
||||
|
|
@ -191,11 +190,11 @@ def main(args) -> None:
|
|||
print(emit(k, to_str=fmt_row))
|
||||
if k["ref"] is not None and k["ref"] not in seen_refs:
|
||||
seen_refs.add(k["ref"])
|
||||
for s in viz_data.ctxs[k["ref"]]["steps"]:
|
||||
for i,s in enumerate(viz_data.ctxs[k["ref"]]["steps"]):
|
||||
if DEBUG >= 3 and s["name"] == "View Base AST": print_step(s)
|
||||
if DEBUG >= 4 and s["name"] == "View Source": print_step(s)
|
||||
if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else '')))
|
||||
if DEBUG >= 6: print_step(s)
|
||||
if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph"): print_step(s, print_graph=True)
|
||||
if DEBUG >= 7 or s["name"] in args.src: print_step(s, reconstruct_matches=True)
|
||||
elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"]))
|
||||
for k in (produce_top_kernels if args.t else produce_all_kernels)(): render_event(k)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from urllib.parse import parse_qs, urlparse
|
|||
from http.server import BaseHTTPRequestHandler
|
||||
from typing import Any, TypedDict, TypeVar, Generator, Callable
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||
from tinygrad.helpers import printable, Context, START_TIME
|
||||
from tinygrad.helpers import printable, Context, START_TIME, NO_COLOR, ansistrip
|
||||
from tinygrad.renderer.amd.dsl import Inst
|
||||
from tinygrad.renderer.amd import detect_format
|
||||
|
||||
|
|
@ -105,6 +105,8 @@ def pystr(u:UOp) -> str:
|
|||
try: return pyrender(u)
|
||||
except Exception: return str(u)
|
||||
|
||||
def fmt_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s
|
||||
|
||||
def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: dict[int, dict] = {}
|
||||
|
|
@ -148,7 +150,8 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
|
||||
except Exception:
|
||||
label += "\n<ISSUE GETTING LABEL>"
|
||||
if (ref:=data.ref_map.get(u.src[0]) if u.op in {Ops.CALL, Ops.FUNCTION} else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
|
||||
ref = data.ref_map.get(u.src[0]) if u.op in {Ops.CALL, Ops.FUNCTION} else None
|
||||
if ref is not None: label += f"\ncodegen@{fmt_colored(data.ctxs[ref]['name'])}"
|
||||
# NOTE: kernel already has metadata in arg
|
||||
if TRACEMETA >= 2 and u.metadata is not None and u.op not in {Ops.CALL, Ops.FUNCTION}: label += "\n"+str(u.metadata)
|
||||
# limit SOURCE labels line count
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue