little cleanups

This commit is contained in:
George Hotz 2025-11-06 13:58:19 -08:00
commit e0d828dba8
2 changed files with 9 additions and 12 deletions

View file

@ -3,9 +3,9 @@ from collections import defaultdict
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, UPat, multirange_str
from tinygrad.helpers import prod, getenv
def linearize(u:UOp) -> list[UOp]:
def linearize(sink:UOp) -> list[UOp]:
# this is a toposort with priority
lst = list(u.toposort())
lst = list(sink.toposort())
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree:dict[UOp, int] = {}
priorities:dict[UOp, tuple[int, int]] = {}
@ -22,16 +22,12 @@ def linearize(u:UOp) -> list[UOp]:
# simple priority override
match u.op:
# the order and placement of these is important
# the order and placement of these defines is important
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG | Ops.DEFINE_VAR: priority = -20
# early consts
case Ops.CONST: priority = -10
# place loads early
case Ops.LOAD: priority = -1
# control flow resets priority
case Ops.RANGE|Ops.END|Ops.IF|Ops.ENDIF: priority = 0
# prevent priority inversion
case _: priority = min([0]+[priorities[x][1] for x in consumers[u]])
case Ops.CONST: priority = -10 # early consts
case Ops.LOAD: priority = -1 # place loads early
case Ops.RANGE|Ops.END|Ops.IF|Ops.ENDIF: priority = 0 # control flow resets priority
case _: priority = min([0]+[priorities[x][1] for x in consumers[u]]) # prevent priority inversion
priorities[u] = (run_count, priority)

View file

@ -856,8 +856,9 @@ def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
# ***** uop helpers *****
def print_uops(uops:list[UOp]):
uops_index = {u:i for i,u in enumerate(uops)}
for i,u in enumerate(uops):
formatted_srcs = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
formatted_srcs = [(uops_index[x] if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {multirange_str(u.ranges, color=True, pad=10)} {str(u.dtype):40s} " f"{str(formatted_srcs):32s} {u.arg}")
# ***** pattern matcher *****