mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
hcq2: move global sync (#16504)
This commit is contained in:
parent
0faaf6df26
commit
3838c8df1b
1 changed files with 15 additions and 22 deletions
|
|
@ -285,14 +285,6 @@ def schedule_inner_sync(ctx:DepsCtx, call:UOp) -> UOp:
|
|||
return op.after(*dps).rtag("deps") if (dps:=dedup(deps)) else op
|
||||
pm_schedule_inner_sync = PatternMatcher([(UPat(Ops.CALL, tag="hcq_cmd", name="call", allow_any_len=True), schedule_inner_sync)])
|
||||
|
||||
def add_global_sync(ctx:DepsCtx, submit:UOp, q:UOp) -> UOp|None:
|
||||
if (dev:=q.arg[0]) in ctx.last_per_queue: return None
|
||||
|
||||
# some devices from a command buffer might be used for the first time this schedule, so we wait for their global timeline epoch.
|
||||
wait = make_signal(dev).wait(make_signal_value(dev).index(UOp.const(dtypes.int, 0)) - 1)
|
||||
return submit.replace(src=(q.replace(src=(wait, *q.src)),))
|
||||
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
|
||||
|
||||
# *****************
|
||||
# 3.2. finalizer
|
||||
|
||||
|
|
@ -328,24 +320,26 @@ def add_loads(ctx:set[int], deps:UOp) -> UOp:
|
|||
val = make_signal_value(cur_devs, queue=queue).index(UOp.const(dtypes.int, 0))
|
||||
waits.append(sig.wait(val + opid))
|
||||
return UOp(Ops.LINEAR, dtypes.void, (*waits, cur), arg=cur.tag[0])
|
||||
pm_add_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
||||
pm_add_inner_loads = PatternMatcher([(UPat(Ops.AFTER, tag="deps", name="deps"), add_loads)])
|
||||
|
||||
def add_stores(ctx:set[int], submit:UOp, lin:UOp) -> UOp:
|
||||
def add_stores(ctx:set[int], submit:UOp, q:UOp) -> UOp:
|
||||
new_src = []
|
||||
for op in lin.src:
|
||||
for op in q.src:
|
||||
new_src.append(op.rtag(None) if op.tag else op)
|
||||
if op.tag and op.tag[1] in ctx:
|
||||
(devs, queue), opid = op.tag
|
||||
new_src.append(make_signal(devs, queue=queue).store(make_signal_value(devs, queue=queue).index(UOp.const(dtypes.int, 0)) + opid))
|
||||
return submit.replace(src=(lin.replace(src=tuple(new_src)),))
|
||||
pm_add_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="lin"),), name="submit"), add_stores)])
|
||||
return submit.replace(src=(q.replace(src=tuple(new_src)),))
|
||||
pm_add_inner_stores = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_stores)])
|
||||
|
||||
# *****************
|
||||
# 3.4. barriers
|
||||
def add_global_sync(ctx:set[tuple[str, ...]], submit:UOp, q:UOp) -> UOp|None:
|
||||
if (devs:=q.arg[0]) in ctx: return None
|
||||
ctx.add(devs)
|
||||
|
||||
def add_barriers(submit:UOp, lin:UOp) -> UOp:
|
||||
return submit.replace(src=(lin.replace(src=(UOp(Ops.BARRIER, dtypes.void), *lin.src)),))
|
||||
pm_add_barriers = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="lin"),), name="submit"), add_barriers)])
|
||||
# some devices from a command buffer might be used for the first time this schedule, so we wait for their global timeline epoch.
|
||||
wait = make_signal(devs).wait(make_signal_value(devs).index(UOp.const(dtypes.int, 0)) - 1)
|
||||
return submit.replace(src=(q.replace(src=(UOp(Ops.BARRIER, dtypes.void), wait, *q.src)),))
|
||||
pm_add_global_sync = PatternMatcher([(UPat(Ops.CUSTOM_FUNCTION, arg="submit", src=(UPat(Ops.LINEAR, name="q"),), name="submit"), add_global_sync)])
|
||||
|
||||
# *****************
|
||||
# 4.1. encode cmdbufs
|
||||
|
|
@ -450,11 +444,10 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
|
||||
linear = graph_rewrite(linear, pm_lower_ops, name="lower ops into hcq ir")
|
||||
linear = graph_rewrite(linear, pm_schedule_inner_sync, ctx=(deps_ctx:=DepsCtx()), walk=True, name="schedule inner sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_global_sync, ctx=deps_ctx, walk=True, name="add global sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_finalizer, ctx=deps_ctx, walk=True, name="add finalizer")
|
||||
linear = graph_rewrite(linear, pm_add_loads + pm_flatten_linear, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_barriers, walk=True, name="add barriers", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_inner_loads + pm_flatten_linear, ctx=(waited:=set()), walk=True, name="add loads", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_inner_stores, ctx=waited, walk=True, name="add stores", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue