group stores by buffer uops [pr] (#7190)

* group stores by buffer uops [pr]

* dedup
This commit is contained in:
qazal 2024-10-21 18:04:44 +03:00 committed by GitHub
commit bc9eb324dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -347,13 +347,13 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
if DEBUG_ARANGE: print(colored(f"folding {r}", "green"))
for tr in group: del realizes[tr]
output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
buf_uops: Dict[Buffer, UOp] = {}
uop_bufs: Dict[UOp, Buffer] = {}
var_vals: Dict[Variable, int] = {}
lazybufs_to_realize: Dict[Buffer, LazyBuffer] = {}
for buf in realizes:
if buf.realized is None and buf.op is not MetaOps.CONST:
output_groups[reduce_for_op.get(buf, buf)].append(buf)
if (dup:=lazybufs_to_realize.get(buf.buffer)) is not None:
raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}")
lazybufs_to_realize[buf.buffer] = buf
@@ -373,10 +373,13 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
uop = UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(buf.dtype.scalar(), buf.arg), v.const_like(0))
# NOTE: UOps.BUFFER creation must come after the ImageDType fixup
else: uop = UOp(UOps.BUFFER, buf.buffer.dtype.ptr(), (), (len(buf_uops), (buf.buffer.device, buf.buffer.size, buf.buffer.dtype)))
buf_uops.setdefault(buf.buffer, uop)
if buf.buffer not in buf_uops:
buf_uops[buf.buffer] = uop
uop_bufs[uop] = buf.buffer
if buf.realized is None and buf.op is not MetaOps.CONST: output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
# preschedule all buffers in realizes
prescheduled = [_lower_lazybuffer(outs, buf_uops, var_vals) for outs in output_groups.values()]
prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, var_vals) for outs in output_groups.values()]
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)