faster block reorder (#9990)

* faster block reorder [pr]

* that shouldn't change order

* key just in sorted

* ind
This commit is contained in:
George Hotz 2025-04-22 19:18:57 +01:00 committed by GitHub
commit feee6986c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 62 additions and 14 deletions

View file

@ -0,0 +1,50 @@
import unittest, random
from tinygrad.dtype import dtypes
from tinygrad.ops import print_uops, UOp, Ops
from tinygrad.codegen.linearize import block_reorder
#from tinygrad.renderer.cstyle import ClangRenderer
def is_toposorted(lst:list[UOp]):
  """Return True if every element of `lst` comes after all of its sources.

  Walks `lst` once, tracking the elements already visited. An element whose
  `src` contains anything not yet visited makes the whole list invalid; this
  includes a source that is not present in `lst` at all.

  Args:
    lst: the candidate ordering; each element must be hashable and expose `src`.

  Returns:
    True for a valid (or empty) ordering, False otherwise.
  """
  seen = set()
  for u in lst:
    # a source that has not been visited yet breaks the topological order
    if any(p not in seen for p in u.src): return False
    seen.add(u)
  return True
class TestBlockReorder(unittest.TestCase):
  """Tests that block_reorder's result does not depend on the input ordering."""

  def test_loads_randomize(self):
    """block_reorder must return the same order for any valid toposort of one graph.

    The graph stores into global `c` the sum of three loads from global `a`
    (indexed by v1, v1+1, v1+2) and three loads from global `b` (indexed by
    v2, v2+1, v2+2). The order produced from `sink.toposort` is the reference;
    50 shuffled-but-valid toposorts of the same uops must reorder to it exactly.
    """
    a = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=0)
    b = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=1)
    c = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=2)
    v1 = UOp(Ops.DEFINE_VAR, dtype=dtypes.int, arg=("a",))
    v2 = UOp(Ops.DEFINE_VAR, dtype=dtypes.int, arg=("b",))
    sink = c.store(sum([
      a.index(v1).load(dtype=dtypes.float),
      a.index(v1+1).load(dtype=dtypes.float),
      a.index(v1+2).load(dtype=dtypes.float),
      b.index(v2).load(dtype=dtypes.float),
      b.index(v2+1).load(dtype=dtypes.float),
      b.index(v2+2).load(dtype=dtypes.float),
    ])).sink()
    golden = block_reorder(sink.toposort)

    # fixed-seed private RNG: a failure reproduces on every run and the global
    # random state used by other tests is left untouched
    rng = random.Random(0)

    # test random order is always same
    for _ in range(50):
      # shuffle and form a valid toposort
      lst = golden[:]
      rng.shuffle(lst)
      topolst = []
      placed = set()  # mirrors topolst for constant-time membership checks
      for u in lst:
        # u.toposort yields u's dependencies before u, so appending each unseen
        # entry in order keeps topolst topologically sorted
        for p in u.toposort:
          if p not in placed:
            placed.add(p)
            topolst.append(p)
      self.assertTrue(is_toposorted(topolst))

      this_order = block_reorder(topolst)
      # zip below stops at the shorter list, so check the lengths explicitly
      self.assertEqual(len(golden), len(this_order))
      for x,y in zip(golden, this_order):
        if x is not y:
          # dump both orders before failing to make the mismatch debuggable
          print_uops(golden)
          print_uops(this_order)
        self.assertIs(x, y)
# run the tests in this file when it is executed directly as a script
if __name__ == '__main__':
  unittest.main()

View file

@ -11,11 +11,12 @@ from tinygrad.spec import type_verify
def block_reorder(lst:list[UOp]) -> list[UOp]:
in_this_block = set(lst)
local_children: defaultdict[UOp, list[UOp]] = defaultdict(list)
in_degree: defaultdict[UOp, int] = defaultdict(int)
in_degree:dict[UOp, int] = {}
priorities:dict[UOp, int] = {}
# get local children and assign priorities
for u in reversed(lst):
in_degree[u] = 0
for s in u.src:
if s in in_this_block:
local_children[s].append(u)
@ -26,21 +27,18 @@ def block_reorder(lst:list[UOp]) -> list[UOp]:
if u.op is Ops.BARRIER: priority.append(-1500)
priorities[u] = min(priority)
# placement queue
queue:list[tuple[int, tuple, UOp]] = []
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
# place the first ones that don't have deps
for u in lst:
if u not in in_degree: push(u)
# number the uops in "ideal" order
nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x], x.tuplize)))}
# then force them to be toposorted in as close to the ideal order as possible
heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0])
newlst = []
while queue:
_,_,x = heapq.heappop(queue)
newlst.append(x)
for u in local_children[x]:
in_degree[u] -= 1
if in_degree[u] == 0: push(u)
while heap:
_,u = heapq.heappop(heap)
newlst.append(u)
for v in local_children[u]:
in_degree[v] -= 1
if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v))
assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}"
return newlst