mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
print_tree UPat +fix (#5132)
* fix and extend print_tree * typing * typing * fix upat * fix none * ws * rm prefix * mv luop dag * typo * test print_tree
This commit is contained in:
parent
0ba093dea0
commit
3a04e518ec
2 changed files with 75 additions and 8 deletions
66
test/test_print_tree.py
Normal file
66
test/test_print_tree.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
#%%
|
||||
import unittest
|
||||
from tinygrad.engine.graph import print_tree
|
||||
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.codegen.uops import UOps, UOp, UPat
|
||||
from tinygrad.ops import BinaryOps
|
||||
|
||||
import sys, io
|
||||
|
||||
class TestPrintTree(unittest.TestCase):
|
||||
|
||||
def _capture_print(self, fn):
|
||||
capturedOutput = io.StringIO()
|
||||
sys.stdout = capturedOutput
|
||||
fn()
|
||||
sys.stdout = sys.__stdout__
|
||||
return capturedOutput.getvalue()
|
||||
|
||||
def test_print_uop(self):
|
||||
x = Tensor.arange(10).schedule()[-1].ast[0]
|
||||
output = self._capture_print(lambda: print_tree(x))
|
||||
assert output == '\
|
||||
0 ━┳ BufferOps.STORE MemBuffer(idx=0, dtype=dtypes.int, \
|
||||
st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))\n\
|
||||
1 ┗━┳ BinaryOps.ADD None\n\
|
||||
2 ┣━┳ ReduceOps.SUM (1,)\n\
|
||||
3 ┃ ┗━━ BufferOps.CONST ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTrac\
|
||||
ker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19))\
|
||||
, contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))\n\
|
||||
4 ┗━━ BufferOps.CONST ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10,\
|
||||
1), strides=(0, 0), offset=0, mask=None, contiguous=False),)))\n'
|
||||
|
||||
x = UOp.var("x", dtypes.int)
|
||||
x = (x + x) - UOp.const(dtypes.int, 2)
|
||||
output = self._capture_print(lambda: print_tree(x))
|
||||
assert output == '\
|
||||
0 ━┳ UOps.ALU BinaryOps.ADD\n\
|
||||
1 ┣━┳ UOps.ALU BinaryOps.ADD\n\
|
||||
2 ┃ ┣━━ UOps.VAR x\n\
|
||||
3 ┃ ┗━━ UOps.VAR x\n\
|
||||
4 ┗━┳ UOps.ALU UnaryOps.NEG\n\
|
||||
5 ┗━━ UOps.CONST 2\n'
|
||||
|
||||
x = UPat(UOp.alu(BinaryOps.ADD, UOp.var("x", dtypes.int), UOp.var("x", dtypes.int)))
|
||||
assert self._capture_print(lambda: print_tree(x)) == '\
|
||||
0 ━━ UOps.ALU : dtypes.int [<UOps.VAR: 2>, <UOps.VAR: 2>] BinaryOps.ADD None\n'
|
||||
|
||||
x = UPat.compile(UOp.store(UOp.var("buf"), UOp.var("idx"),
|
||||
UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store)
|
||||
assert self._capture_print(lambda: print_tree(x)) == '\
|
||||
0 ━┳ UOps.STORE None\n\
|
||||
1 ┣━━ None None\n\
|
||||
2 ┣━━ None None\n\
|
||||
3 ┗━┳ UOps.CAST None\n\
|
||||
4 ┣━┳ UOps.GEP 0\n\
|
||||
5 ┃ ┗━━ None None\n\
|
||||
6 ┣━┳ UOps.GEP 1\n\
|
||||
7 ┃ ┗━━ None None\n\
|
||||
8 ┣━┳ UOps.GEP 2\n\
|
||||
9 ┃ ┗━━ None None\n\
|
||||
10 ┗━┳ UOps.GEP 3\n\
|
||||
11 ┗━━ None None\n'
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -4,7 +4,7 @@ from typing import List, Any, DefaultDict, Union
|
|||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
|
||||
from tinygrad.codegen.uops import UOps, UOp
|
||||
from tinygrad.codegen.uops import UOps, UOp, UPat
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
||||
|
|
@ -75,18 +75,19 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
|
|||
# realized but unseen?
|
||||
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
|
||||
|
||||
def _tree(luop:Union[LazyOp,UOp], cycles, cnt, prefix=""):
|
||||
def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
|
||||
cnt[0] += 1
|
||||
if len(luop.src) == 0: return [f"━━ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
|
||||
if (lid := id(luop)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
|
||||
return [f"━⬆︎ goto {cycles[id(luop)][0]}: {luop.op.name}"]
|
||||
src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
|
||||
if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
|
||||
if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
|
||||
return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
|
||||
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
|
||||
lines = [f"━┳ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
|
||||
childs = [_tree(c, cycles, cnt) for c in luop.src[:]]
|
||||
lines = [f"━┳ {dag.op} {dag.arg}"]
|
||||
childs = [_tree(c, cycles, cnt) for c in src]
|
||||
for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
|
||||
return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
|
||||
|
||||
def print_tree(luop:Union[LazyOp,UOp]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(luop, {}, [-1]))]))
|
||||
def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
|
||||
|
||||
def graph_uops(uops:List[UOp]):
|
||||
colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue