print_tree UPat +fix (#5132)

* fix and extend print_tree

* typing

* typing

* fix upat

* fix none

* ws

* rm prefix

* mv luop dag

* typo

* test print_tree
This commit is contained in:
kormann 2024-06-27 00:02:19 +02:00 committed by GitHub
commit 3a04e518ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 75 additions and 8 deletions

66
test/test_print_tree.py Normal file
View file

@ -0,0 +1,66 @@
#%%
import unittest
from tinygrad.engine.graph import print_tree
from tinygrad import Tensor, dtypes
from tinygrad.codegen.uops import UOps, UOp, UPat
from tinygrad.ops import BinaryOps
import sys, io
class TestPrintTree(unittest.TestCase):
def _capture_print(self, fn):
capturedOutput = io.StringIO()
sys.stdout = capturedOutput
fn()
sys.stdout = sys.__stdout__
return capturedOutput.getvalue()
def test_print_uop(self):
x = Tensor.arange(10).schedule()[-1].ast[0]
output = self._capture_print(lambda: print_tree(x))
assert output == '\
0 BufferOps.STORE MemBuffer(idx=0, dtype=dtypes.int, \
st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))\n\
1 BinaryOps.ADD None\n\
2 ReduceOps.SUM (1,)\n\
3 BufferOps.CONST ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTrac\
ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19))\
, contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))\n\
4 BufferOps.CONST ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10,\
1), strides=(0, 0), offset=0, mask=None, contiguous=False),)))\n'
x = UOp.var("x", dtypes.int)
x = (x + x) - UOp.const(dtypes.int, 2)
output = self._capture_print(lambda: print_tree(x))
assert output == '\
0 UOps.ALU BinaryOps.ADD\n\
1 UOps.ALU BinaryOps.ADD\n\
2 UOps.VAR x\n\
3 UOps.VAR x\n\
4 UOps.ALU UnaryOps.NEG\n\
5 UOps.CONST 2\n'
x = UPat(UOp.alu(BinaryOps.ADD, UOp.var("x", dtypes.int), UOp.var("x", dtypes.int)))
assert self._capture_print(lambda: print_tree(x)) == '\
0 UOps.ALU : dtypes.int [<UOps.VAR: 2>, <UOps.VAR: 2>] BinaryOps.ADD None\n'
x = UPat.compile(UOp.store(UOp.var("buf"), UOp.var("idx"),
UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store)
assert self._capture_print(lambda: print_tree(x)) == '\
0 UOps.STORE None\n\
1 None None\n\
2 None None\n\
3 UOps.CAST None\n\
4 UOps.GEP 0\n\
5 None None\n\
6 UOps.GEP 1\n\
7 None None\n\
8 UOps.GEP 2\n\
9 None None\n\
10 UOps.GEP 3\n\
11 None None\n'
if __name__ == "__main__":
unittest.main()

View file

@ -4,7 +4,7 @@ from typing import List, Any, DefaultDict, Union
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uops import UOps, UOp, UPat
from tinygrad.shape.symbolic import NumNode
from tinygrad.lazy import LazyBuffer
@ -75,18 +75,19 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
# realized but unseen?
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
def _tree(luop:Union[LazyOp,UOp], cycles, cnt, prefix=""):
def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
cnt[0] += 1
if len(luop.src) == 0: return [f"━━ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
if (lid := id(luop)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
return [f"━⬆︎ goto {cycles[id(luop)][0]}: {luop.op.name}"]
src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
lines = [f"━┳ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
childs = [_tree(c, cycles, cnt) for c in luop.src[:]]
lines = [f"━┳ {dag.op} {dag.arg}"]
childs = [_tree(c, cycles, cnt) for c in src]
for c in childs[:-1]: lines += [f"{c[0]}"] + [f"{l}" for l in c[1:]]
return lines + [""+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
def print_tree(luop:Union[LazyOp,UOp]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(luop, {}, [-1]))]))
def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
def graph_uops(uops:List[UOp]):
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",