track metadata with uops [pr] (#7188)

This commit is contained in:
qazal 2024-10-21 16:35:46 +03:00 committed by GitHub
commit 37b829ef0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -155,17 +155,18 @@ if getenv("RUN_PROCESS_REPLAY"):
# *** List[LazyBuffer] lowering to ScheduleItem ***
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp:
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata],
cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, cache).view(buf.st)
cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, metadata, cache).view(buf.st)
return ret
if buf.op is MetaOps.CONST: return buf_uops[buf.buffer]
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
return UOp.load(ubuf, buf.st.to_uop(), dtype=dtype)
src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs)
src = tuple(to_uop(x, outputs, inputs, buf_uops, metadata, cache) for x in buf.srcs)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
@ -174,14 +175,16 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
cache[buf] = ret
if buf.metadata is not None: metadata[ret] = buf.metadata
return ret
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_vals:Dict[Variable, int]) -> LBScheduleItem:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
cache: Dict[LazyBuffer, UOp] = {}
inputs: List[LazyBuffer] = []
metadata: Dict[UOp, Metadata] = {}
sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
to_uop(out, outs, inputs, buf_uops, cache)) for out in outs))
to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
sink = full_ast_rewrite(sink, tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
@ -189,8 +192,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
return LBScheduleItem(sink, tuple(outs+inputs),
tuple(dedup([x.metadata for x in cache if x.metadata is not None and (x.base in outs or x.base.buffer not in buf_uops)])))
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup(metadata.values())))
# *** DAG creation: decide which LazyBuffers should realize ***