Compare commits

...

1 commit

Author SHA1 Message Date
George Hotz
f215f84241 use end range count in priority 2025-11-06 10:17:35 -08:00
2 changed files with 31 additions and 9 deletions

View file

@ -38,6 +38,12 @@ class TestLinearizer(unittest.TestCase):
np.testing.assert_equal(a.numpy(), ta)
np.testing.assert_equal(b.numpy(), tb)
def test_late_bias_load(self):
# Smoke test: runs conv2d on a (1, 3, 16, 16) input with a (16, 3, 3, 3) weight
# and a (16,) bias, all created with Tensor.empty, then realizes the result.
# There are no assertions; the test passes as long as realize() does not raise.
# NOTE(review): the name suggests this targets where the linearizer places the
# bias load -- confirm against the linearizer priority change it accompanies.
img = Tensor.empty(1, 3, 16, 16)
w = Tensor.empty(16, 3, 3, 3)
b = Tensor.empty(16)
img.conv2d(w, b).realize()
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])

View file

@ -8,7 +8,8 @@ def linearize(u:UOp) -> list[UOp]:
lst = list(u.toposort())
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree:dict[UOp, int] = {}
priorities:dict[UOp, tuple[int, int]] = {}
priorities:dict[UOp, tuple[int, int, int, int]] = {}
ended_ranges:dict[UOp, dict[UOp, None]] = {}
# get consumers and assign priorities
# NOTE: this requires the lst be locally toposorted
@ -16,24 +17,39 @@ def linearize(u:UOp) -> list[UOp]:
for s in u.src: consumers[s].append(u)
in_degree[u] = len(u.src)
# we place UOps upstream of more end ranges earlier
ended_ranges[u] = {}
for x in consumers[u]:
if x.op is Ops.END: ended_ranges[u][x] = None
ended_ranges[u].update(ended_ranges[x])
# we place UOps with higher run_counts later
# this will cause ranges to be placed late and ends to be placed early
run_count = prod([int(r.vmax)+1 for r in u.ranges])
# simple priority override
# simple op priority
match u.op:
# the order and placement of these is important
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG | Ops.DEFINE_VAR: priority = -20
# the order and placement of these is important. they end the loop early
case Ops.DEFINE_GLOBAL | Ops.DEFINE_VAR | Ops.DEFINE_LOCAL | Ops.DEFINE_REG:
priorities[u] = (-20, 0, 0, 0)
continue
# early consts
case Ops.CONST: priority = -10
# place loads early
case Ops.LOAD: priority = -1
# control flow resets priority
case Ops.RANGE|Ops.END|Ops.IF|Ops.ENDIF: priority = 0
# prevent priority inversion
case _: priority = min([0]+[priorities[x][1] for x in consumers[u]])
case Ops.CONST: op_priority = -10
# place END as soon as you can
case Ops.END: op_priority = -100
# nothing else has op_priority
case _: op_priority = 0
priorities[u] = (run_count, priority)
# load priority
match u.op:
# place loads early
case Ops.LOAD: load_priority = -1
# control flow resets priority
case Ops.RANGE|Ops.IF|Ops.ENDIF: load_priority = 0
# prevent priority inversion
case _: load_priority = min([0]+[priorities[x][-1] for x in consumers[u]])
priorities[u] = (op_priority, -len(ended_ranges[u]), run_count, load_priority)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+x.tuplize))}