mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
1 commit
master
...
er_prioity
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f215f84241 |
2 changed files with 31 additions and 9 deletions
|
|
@ -38,6 +38,12 @@ class TestLinearizer(unittest.TestCase):
|
|||
np.testing.assert_equal(a.numpy(), ta)
|
||||
np.testing.assert_equal(b.numpy(), tb)
|
||||
|
||||
def test_late_bias_load(self):
|
||||
img = Tensor.empty(1, 3, 16, 16)
|
||||
w = Tensor.empty(16, 3, 3, 3)
|
||||
b = Tensor.empty(16)
|
||||
img.conv2d(w, b).realize()
|
||||
|
||||
def _test_no_nested_ranges(self, lins, skip=None):
|
||||
for l in lins:
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ def linearize(u:UOp) -> list[UOp]:
|
|||
lst = list(u.toposort())
|
||||
consumers: defaultdict[UOp, list[UOp]] = defaultdict(list)
|
||||
in_degree:dict[UOp, int] = {}
|
||||
priorities:dict[UOp, tuple[int, int]] = {}
|
||||
priorities:dict[UOp, tuple[int, int, int, int]] = {}
|
||||
ended_ranges:dict[UOp, dict[UOp, None]] = {}
|
||||
|
||||
# get consumers and assign priorities
|
||||
# NOTE: this requires the lst be locally toposorted
|
||||
|
|
@ -16,24 +17,39 @@ def linearize(u:UOp) -> list[UOp]:
|
|||
for s in u.src: consumers[s].append(u)
|
||||
in_degree[u] = len(u.src)
|
||||
|
||||
# we place UOps upstream of more end ranges earlier
|
||||
ended_ranges[u] = {}
|
||||
for x in consumers[u]:
|
||||
if x.op is Ops.END: ended_ranges[u][x] = None
|
||||
ended_ranges[u].update(ended_ranges[x])
|
||||
|
||||
# we place UOps with higher run_counts later
|
||||
# this will cause ranges to be placed late and ends to be placed early
|
||||
run_count = prod([int(r.vmax)+1 for r in u.ranges])
|
||||
|
||||
# simple priority override
|
||||
# simple op priority
|
||||
match u.op:
|
||||
# the order and placement of these is important
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG | Ops.DEFINE_VAR: priority = -20
|
||||
# the order and placement of these is important. they end the loop early
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_VAR | Ops.DEFINE_LOCAL | Ops.DEFINE_REG:
|
||||
priorities[u] = (-20, 0, 0, 0)
|
||||
continue
|
||||
# early consts
|
||||
case Ops.CONST: priority = -10
|
||||
# place loads early
|
||||
case Ops.LOAD: priority = -1
|
||||
# control flow resets priority
|
||||
case Ops.RANGE|Ops.END|Ops.IF|Ops.ENDIF: priority = 0
|
||||
# prevent priority inversion
|
||||
case _: priority = min([0]+[priorities[x][1] for x in consumers[u]])
|
||||
case Ops.CONST: op_priority = -10
|
||||
# place END as soon as you can
|
||||
case Ops.END: op_priority = -100
|
||||
# nothing else has op_priority
|
||||
case _: op_priority = 0
|
||||
|
||||
priorities[u] = (run_count, priority)
|
||||
# load priority
|
||||
match u.op:
|
||||
# place loads early
|
||||
case Ops.LOAD: load_priority = -1
|
||||
# control flow resets priority
|
||||
case Ops.RANGE|Ops.IF|Ops.ENDIF: load_priority = 0
|
||||
# prevent priority inversion
|
||||
case _: load_priority = min([0]+[priorities[x][-1] for x in consumers[u]])
|
||||
|
||||
priorities[u] = (op_priority, -len(ended_ranges[u]), run_count, load_priority)
|
||||
|
||||
# number the uops in "ideal" order
|
||||
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]+x.tuplize))}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue