mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
a51b5ba424
commit
5ebd44aa12
1 changed files with 38 additions and 1 deletions
|
|
@ -332,6 +332,42 @@ def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
|
|||
return submit.replace(src=(q.replace(src=tuple(new_src)),))
|
||||
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
|
||||
|
||||
# *****************
|
||||
# 4.1. merge queues
|
||||
|
||||
def get_submit(ast:UOp) -> UOp: return next(u for u in ast.toposort() if u.op is Ops.CUSTOM_FUNCTION and u.arg == "submit")
|
||||
|
||||
def merge_sinks(old_sink:UOp, new_sink:UOp) -> UOp:
|
||||
old_submit, new_submit = get_submit(old_sink), get_submit(new_sink)
|
||||
old_queue, new_queue = old_submit.src[0], new_submit.src[0]
|
||||
merged_submit = new_submit.replace(src=(new_queue.replace(src=old_queue.src + new_queue.src),))
|
||||
old_root = old_sink.src[0].substitute({old_submit: merged_submit})
|
||||
new_anchor = merged_submit if old_sink.src[0] is old_submit else old_root
|
||||
return new_sink.substitute({new_submit: new_anchor})
|
||||
|
||||
def merge_queues(linear:UOp) -> UOp:
|
||||
new_src:list[UOp] = []
|
||||
opened_qs:dict[tuple[tuple[str, ...], str], UOp] = {} # (devs, queue) -> sink, kept in submit order
|
||||
|
||||
for call in linear.src:
|
||||
if call.tag != "hcq":
|
||||
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in list(opened_qs)] + [call]
|
||||
continue
|
||||
|
||||
devs, queue = get_submit(new_sink:=call.src[0]).src[0].arg
|
||||
if (old_sink:=opened_qs.pop((devs, queue), None)) is not None:
|
||||
new_sink = merge_sinks(old_sink, new_sink) # exact same queue: merge, and re-insert at the end
|
||||
else:
|
||||
# no such queue opened: close every open submit on this queue that shares a device, so submit order is kept
|
||||
new_src += [opened_qs.pop(k).call().rtag('hcq') for k in [k for k in opened_qs if k[1] == queue and set(k[0]) & set(devs)]]
|
||||
opened_qs[(devs, queue)] = new_sink
|
||||
|
||||
return linear.replace(src=tuple(new_src + [sink.call().rtag('hcq') for sink in opened_qs.values()]))
|
||||
pm_merge_queues = PatternMatcher([(UPat(Ops.LINEAR, name="linear"), merge_queues)])
|
||||
|
||||
# *****************
|
||||
# 4.2. global sync
|
||||
|
||||
def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
|
||||
if (devs:=q.arg[0]) in ctx: return None
|
||||
ctx.add(devs)
|
||||
|
|
@ -342,7 +378,7 @@ def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
|
|||
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
|
||||
|
||||
# *****************
|
||||
# 4.1. encode cmdbufs
|
||||
# 5.1. encode cmdbufs
|
||||
|
||||
@functools.cache
|
||||
def get_pm_lower(name:str) -> PatternMatcher|None:
|
||||
|
|
@ -447,6 +483,7 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
linear = graph_rewrite(linear, pm_add_finalizer, ctx=deps_ctx, walk=True, name="add finalizer")
|
||||
linear = graph_rewrite(linear, pm_add_inner_loads + pm_flatten_linear, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_merge_queues, name="merge queues")
|
||||
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue