mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
faster viz to_program [pr] (#10410)
* faster viz to_program [pr] * Callable
This commit is contained in:
parent
cc8dda1d75
commit
f9a5ad24c5
1 changed files with 6 additions and 10 deletions
|
|
@ -2,7 +2,7 @@
|
|||
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal, socketserver
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, Callable, TypedDict, Generator
|
||||
from typing import Any, TypedDict, Generator
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp, srender, sint
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
|
|
@ -19,12 +19,6 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
|
||||
# VIZ API
|
||||
|
||||
# NOTE: if any extra rendering in VIZ fails, we don't crash
|
||||
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
|
||||
err = kwargs.pop("err", "")
|
||||
try: return fxn(*args, **kwargs)
|
||||
except Exception: return f"ERROR in {fxn.__name__}\n{err}"
|
||||
|
||||
# ** Metadata for a track_rewrites scope
|
||||
|
||||
class GraphRewriteMetadata(TypedDict):
|
||||
|
|
@ -36,11 +30,13 @@ class GraphRewriteMetadata(TypedDict):
|
|||
depth: int # depth if it's a subrewrite
|
||||
|
||||
@functools.cache
|
||||
def render_program(k:Kernel): return k.opts.render(k.uops)
|
||||
def render_program(k:Kernel):
|
||||
try: return k.opts.render(k.uops)
|
||||
except Exception as e: return f"ISSUE RENDERING KERNEL: {e}\nast = {k.ast}\nopts = {k.applied_opts}"
|
||||
|
||||
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
|
||||
return {"loc":v.loc, "match_count":len(v.matches), "name":v.name, "depth":v.depth, "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
|
||||
"kernel_code":pcall(render_program, k, err=f"ast = {k.ast}\nopts = {k.applied_opts}") if isinstance(k, Kernel) else None}
|
||||
"kernel_code":render_program(k) if isinstance(k, Kernel) else None}
|
||||
|
||||
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
|
||||
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
|
||||
|
|
@ -97,7 +93,7 @@ def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails
|
|||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RecursionError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
||||
"diff":list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())), "upat":(upat.location, upat.printable())}
|
||||
"diff":list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())), "upat":(upat.location, upat.printable())}
|
||||
if not ctx.bottom_up: next_sink = new_sink
|
||||
|
||||
# Profiler API
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue