add err handling tests to viz + cleanups (#9825)

* cleanup

* add err handling tests to viz + cleanups

* lint
This commit is contained in:
qazal 2025-04-10 14:05:05 +08:00 committed by GitHub
commit 498a2bf738
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 2 deletions

View file

@ -4,7 +4,7 @@ from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_re
from tinygrad.codegen.symbolic import symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto
from tinygrad.viz.serve import get_metadata, get_details, uop_to_json, to_perfetto
# NOTE: VIZ tests always use the tracked PatternMatcher instance
symbolic = TrackedPatternMatcher(symbolic.patterns)
@ -174,6 +174,27 @@ class TestViz(unittest.TestCase):
self.assertEqual([x.name for x in tracked], ["outer", "inner_x", "inner_y"])
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
def test_shape_label(self):
a = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((4,))
b = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((8,))
n = a+b
ser = uop_to_json(n)
self.assertIn("(4,)", ser[id(a)]["label"])
self.assertIn("(8,)", ser[id(b)]["label"])
with self.assertRaises(AssertionError): n.st
_ = ser[id(n)]["label"] # VIZ should not crash
@unittest.skip("TODO: doesn't work")
def test_recursion_err(self):
inf = TrackedPatternMatcher([
(UPat.const(dtypes.int, 0).named("a"), lambda a: a.const_like(1)),
(UPat.const(dtypes.int, 1).named("b"), lambda b: b.const_like(0)),
])
@track_rewrites(named=True)
def func(u): return graph_rewrite(u, inf)
with self.assertRaises(RecursionError): func(UOp.const(dtypes.int, 0))
_ = list(get_details(keys[0], contexts[0][0]))
class TextVizProfiler(unittest.TestCase):
def test_perfetto_node(self):
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),

View file

@ -34,9 +34,11 @@ class GraphRewriteMetadata(TypedDict):
@functools.cache
def render_program(k:Kernel): return k.opts.render(k.uops)
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
"kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name}
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
@ -45,7 +47,7 @@ def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> li
class GraphRewriteDetails(TypedDict):
graph: dict # JSON serialized UOp for this rewrite step
uop: str # strigified UOp for this rewrite step
diff: list[str]|None # string diff of the single UOp that changed
diff: list[str]|None # diff of the single UOp that changed
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat