faster viz to_program [pr] (#10410)

* faster viz to_program [pr]

* Callable
This commit is contained in:
qazal 2025-05-19 12:27:49 +03:00 committed by GitHub
commit f9a5ad24c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2,7 +2,7 @@
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, Callable, TypedDict, Generator
from typing import Any, TypedDict, Generator
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp, srender, sint
from tinygrad.codegen.kernel import Kernel
@ -19,12 +19,6 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
# VIZ API
# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
err = kwargs.pop("err", "")
try: return fxn(*args, **kwargs)
except Exception: return f"ERROR in {fxn.__name__}\n{err}"
# ** Metadata for a track_rewrites scope
class GraphRewriteMetadata(TypedDict):
@ -36,11 +30,13 @@ class GraphRewriteMetadata(TypedDict):
depth: int # depth if it's a subrewrite
@functools.cache
def render_program(k:Kernel): return k.opts.render(k.uops)
def render_program(k:Kernel):
try: return k.opts.render(k.uops)
except Exception as e: return f"ISSUE RENDERING KERNEL: {e}\nast = {k.ast}\nopts = {k.applied_opts}"
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
return {"loc":v.loc, "match_count":len(v.matches), "name":v.name, "depth":v.depth, "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
"kernel_code":pcall(render_program, k, err=f"ast = {k.ast}\nopts = {k.applied_opts}") if isinstance(k, Kernel) else None}
"kernel_code":render_program(k) if isinstance(k, Kernel) else None}
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
@ -97,7 +93,7 @@ def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails
try: new_sink = next_sink.substitute(replaces)
except RecursionError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())), "upat":(upat.location, upat.printable())}
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat.location, upat.printable())}
if not ctx.bottom_up: next_sink = new_sink
# Profiler API