replace viz graph when it's sink (#6541)

This commit is contained in:
qazal 2024-09-16 16:00:27 +08:00 committed by GitHub
commit dae3615008
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
import pickle, re, os, sys, time, threading, webbrowser, json, difflib
from tinygrad.helpers import getenv
from tinygrad.ops import TrackedRewriteContext, UOp
from tinygrad.ops import TrackedRewriteContext, UOp, UOps
from tinygrad.engine.graph import uops_colors, word_wrap
from http.server import HTTPServer, BaseHTTPRequestHandler
@ -43,7 +43,10 @@ def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
extra: List[List[str]] = [[str(ctx.sink)]]
for (first, rewritten, pattern) in ctx.rewrites:
diffs.append((pattern, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
uops.append(new_sink:=replace_uop(uops[-1], first, rewritten, {}))
# if the sink was replaced, we have to replace the entire graph, otherwise just replace the parent
new_sink = rewritten if first.op is UOps.SINK else replace_uop(uops[-1], first, rewritten, {})
assert new_sink.op is UOps.SINK
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, list(map(uop_to_json, uops)), diffs, extra)