mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add err handling tests to viz + cleanups (#9825)
* cleanup * add err handling tests to viz + cleanups * lint
This commit is contained in:
parent
a0b72f066a
commit
498a2bf738
2 changed files with 25 additions and 2 deletions
|
|
@ -4,7 +4,7 @@ from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_re
|
|||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto
|
||||
from tinygrad.viz.serve import get_metadata, get_details, uop_to_json, to_perfetto
|
||||
|
||||
# NOTE: VIZ tests always use the tracked PatternMatcher instance
|
||||
symbolic = TrackedPatternMatcher(symbolic.patterns)
|
||||
|
|
@ -174,6 +174,27 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual([x.name for x in tracked], ["outer", "inner_x", "inner_y"])
|
||||
self.assertEqual([len(x.matches) for x in tracked], [1, 1, 1])
|
||||
|
||||
def test_shape_label(self):
|
||||
a = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((4,))
|
||||
b = UOp.new_buffer("CPU", 1, dtypes.uint8).expand((8,))
|
||||
n = a+b
|
||||
ser = uop_to_json(n)
|
||||
self.assertIn("(4,)", ser[id(a)]["label"])
|
||||
self.assertIn("(8,)", ser[id(b)]["label"])
|
||||
with self.assertRaises(AssertionError): n.st
|
||||
_ = ser[id(n)]["label"] # VIZ should not crash
|
||||
|
||||
@unittest.skip("TODO: doesn't work")
|
||||
def test_recursion_err(self):
|
||||
inf = TrackedPatternMatcher([
|
||||
(UPat.const(dtypes.int, 0).named("a"), lambda a: a.const_like(1)),
|
||||
(UPat.const(dtypes.int, 1).named("b"), lambda b: b.const_like(0)),
|
||||
])
|
||||
@track_rewrites(named=True)
|
||||
def func(u): return graph_rewrite(u, inf)
|
||||
with self.assertRaises(RecursionError): func(UOp.const(dtypes.int, 0))
|
||||
_ = list(get_details(keys[0], contexts[0][0]))
|
||||
|
||||
class TextVizProfiler(unittest.TestCase):
|
||||
def test_perfetto_node(self):
|
||||
prof = [ProfileRangeEvent(device='NV', name='E_2', st=decimal.Decimal(1000), en=decimal.Decimal(1010), is_copy=False),
|
||||
|
|
|
|||
|
|
@ -34,9 +34,11 @@ class GraphRewriteMetadata(TypedDict):
|
|||
|
||||
@functools.cache
|
||||
def render_program(k:Kernel): return k.opts.render(k.uops)
|
||||
|
||||
def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
|
||||
return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),
|
||||
"kernel_code":pcall(render_program, k) if isinstance(k, Kernel) else None, "name":v.name}
|
||||
|
||||
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
|
||||
return [(k.name if isinstance(k, Kernel) else str(k), [to_metadata(k, v) for v in vals]) for k,vals in zip(keys, contexts)]
|
||||
|
||||
|
|
@ -45,7 +47,7 @@ def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> li
|
|||
class GraphRewriteDetails(TypedDict):
|
||||
graph: dict # JSON serialized UOp for this rewrite step
|
||||
uop: str # strigified UOp for this rewrite step
|
||||
diff: list[str]|None # string diff of the single UOp that changed
|
||||
diff: list[str]|None # diff of the single UOp that changed
|
||||
changed_nodes: list[int]|None # the changed UOp id + all its parents ids
|
||||
upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue