mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
little cleanups
This commit is contained in:
parent
bfb0c0391f
commit
e0d828dba8
2 changed files with 9 additions and 12 deletions
|
|
@ -3,9 +3,9 @@ from collections import defaultdict
|
|||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
|
||||
from tinygrad.helpers import prod, getenv
|
||||
|
||||
def linearize(u:UOp) -> list[UOp]:
|
||||
def linearize(sink:UOp) -> list[UOp]:
|
||||
# this is a toposort with priority
|
||||
lst = list(u.toposort())
|
||||
lst = list(sink.toposort())
|
||||
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree:dict[UOp, int] = {}
|
||||
priorities:dict[UOp, tuple[int, int]] = {}
|
||||
|
|
@ -22,16 +22,12 @@ def linearize(u:UOp) -> list[UOp]:
|
|||
|
||||
# simple priority override
|
||||
match u.op:
|
||||
# the order and placement of these is important
|
||||
# the order and placement of these defines is important
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG | Ops.DEFINE_VAR: priority = -20
|
||||
# early consts
|
||||
case Ops.CONST: priority = -10
|
||||
# place loads early
|
||||
case Ops.LOAD: priority = -1
|
||||
# control flow resets priority
|
||||
case Ops.RANGE|Ops.END|Ops.IF|Ops.ENDIF: priority = 0
|
||||
# prevent priority inversion
|
||||
case _: priority = min([0]+[priorities[x][1] for x in consumers[u]])
|
||||
case Ops.CONST: priority = -10 # early consts
|
||||
case Ops.LOAD: priority = -1 # place loads early
|
||||
case Ops.RANGE|Ops.END|Ops.IF|Ops.ENDIF: priority = 0 # control flow resets priority
|
||||
case _: priority = min([0]+[priorities[x][1] for x in consumers[u]]) # prevent priority inversion
|
||||
|
||||
priorities[u] = (run_count, priority)
|
||||
|
||||
|
|
|
|||
|
|
@ -856,8 +856,9 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
|
|||
# ***** uop helpers *****
|
||||
|
||||
def print_uops(uops:list[UOp]):
|
||||
uops_index = {u:i for i,u in enumerate(uops)}
|
||||
for i,u in enumerate(uops):
|
||||
formatted_srcs = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
|
||||
formatted_srcs = [(uops_index[x] if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
|
||||
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
|
||||
|
||||
# ***** pattern matcher *****
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue