mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
parent
fb74f75485
commit
2bfdf85f87
1 changed files with 10 additions and 10 deletions
|
|
@ -410,6 +410,14 @@ pm_lift_patches_to_cmdbuf = PatternMatcher([
|
|||
(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, tag={"compute", "copy"}),), allow_any_len=True, name="cmdbuf"), lift_patches_to_cmdbuf),
|
||||
])
|
||||
|
||||
# *****************
|
||||
# 5.3. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
|
||||
def hold_call_buffers(call:UOp) -> UOp|None:
  """Attach the BUFFERs found under the call's first source to the call as one extra BIND source.

  Walks `call.src[0].toposort()` and collects (deduplicated) every uop whose op is `Ops.BUFFER`
  and that is not already one of `call.src`. Per the section comment, the point is to keep
  references to those buffers from being dropped.

  Returns None when there is nothing to collect (presumably "no rewrite" for the matcher),
  otherwise a copy of `call` whose `src` has a trailing `Ops.BIND` (dtype void) holding the buffers.
  """
  held = tuple(dedup(node for node in call.src[0].toposort() if node.op is Ops.BUFFER and node not in call.src))
  if not held: return None
  bind = UOp(Ops.BIND, dtypes.void, src=held)
  return call.replace(src=call.src + (bind,))
|
||||
# single rule: a CALL pattern with tag "hcq" (bound as "call") is handed to hold_call_buffers
pm_hold_call_buffers = PatternMatcher([
  (UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers),
])
|
||||
|
||||
# *****************
|
||||
# 6. bufferize placeholders: replace placeholders with real buffers.
|
||||
|
||||
|
|
@ -420,15 +428,7 @@ def bufferize_buf(buf:UOp) -> UOp|None:
|
|||
# single rule: a BUFFER pattern (bound as "buf") is handed to bufferize_buf
pm_bufferize = PatternMatcher([
  (UPat(Ops.BUFFER, name="buf"), bufferize_buf),
])
|
||||
|
||||
# *****************
|
||||
# 7.1. capture buffers reachable from each hcq call as BIND, so we don't drop their refs
|
||||
|
||||
def hold_call_buffers(call:UOp) -> UOp|None:
  """Return `call` with an extra BIND source holding the BUFFERs reachable from `call.src[0]`.

  Only uops with op `Ops.BUFFER` that are not already in `call.src` are kept, deduplicated via
  `dedup`. If none are found the function returns None; otherwise the BIND (dtype void) is
  appended after the existing sources.
  """
  reachable = call.src[0].toposort()
  extra_bufs = tuple(dedup(u for u in reachable if u.op is Ops.BUFFER and u not in call.src))
  return call.replace(src=call.src + (UOp(Ops.BIND, dtypes.void, src=extra_bufs),)) if extra_bufs else None
|
||||
# matcher with one rule: CALL pattern, tag "hcq", captured as "call" -> hold_call_buffers
_hold_call_rule = (UPat(Ops.CALL, tag="hcq", name="call"), hold_call_buffers)
pm_hold_call_buffers = PatternMatcher([_hold_call_rule])
|
||||
|
||||
# *****************
|
||||
# 7.2. resolve patches
|
||||
# 7. resolve patches
|
||||
|
||||
def push_stack(op, s):
  """Build a STACK with one scalar-dtype copy of `op` per source of `s`.

  For each element `x` of `s.src`, `op` is replaced by a copy with `op.dtype.scalar()` as dtype
  and with every source that is `s` itself (identity comparison) swapped for `x`; other sources
  are kept as they are. The resulting uops become the sources of an `Ops.STACK` whose dtype is
  the scalar dtype vectorized to `len(s.src)`.
  """
  # computed first, matching the original argument evaluation order
  stack_dtype = op.dtype.scalar().vec(len(s.src))
  lanes = []
  for elem in s.src:
    lane_src = tuple(elem if operand is s else operand for operand in op.src)
    lanes.append(op.replace(dtype=op.dtype.scalar(), src=lane_src))
  return UOp(Ops.STACK, stack_dtype, tuple(lanes))
|
||||
|
|
@ -494,10 +494,10 @@ def hcq_schedule(linear:UOp) -> UOp:
|
|||
linear = graph_rewrite(linear, pm_add_global_sync, ctx=set(), walk=True, name="add global sync", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_encode_cmdbufs, walk=True, name="encode cmdbufs", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_lift_patches_to_cmdbuf, name="lift patches to cmdbuf", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
||||
|
||||
# realize starts from here
|
||||
linear = graph_rewrite(linear, pm_bufferize, bottom_up=True, name="bufferize placeholders", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_hold_call_buffers, walk=True, name="hold call buffers")
|
||||
linear = graph_rewrite(linear, pm_resolve_patches, bottom_up=False, name="simplify patches", enter_calls=True)
|
||||
linear = graph_rewrite(linear, pm_parametrize_host_buffers, name="parametrize host buffers")
|
||||
linear = graph_rewrite(linear, pm_callify_hcq, name="callify hcq")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue