all toposort

This commit is contained in:
George Hotz 2024-12-02 14:28:45 +08:00
commit db123adfda

View file

@ -130,8 +130,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
# add BLOCKFORK (slow!)
block_parent_count = collections.Counter(flatten([x.src for x in sink.sparents if x.op is Ops.BLOCK]))
non_block_parents = flatten([x.src for x in sink.sparents if x.op is not Ops.BLOCK])
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
non_block_parents = flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK])
forks = {}
for u,child_count in block_parent_count.items():
if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents:
@ -142,7 +142,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
# combine matching BLOCKENDS
blockends_to_arg: Dict[UOp, List[UOp]] = {}
for be in sink.sparents:
for be in sink.toposort:
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
new_forks = {}
for k,v in blockends_to_arg.items():