mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
track metadata with uops [pr] (#7188)
This commit is contained in:
parent
5551cf6689
commit
37b829ef0d
1 changed files with 8 additions and 6 deletions
|
|
@ -155,17 +155,18 @@ if getenv("RUN_PROCESS_REPLAY"):
|
|||
|
||||
# *** List[LazyBuffer] lowering to ScheduleItem ***
|
||||
|
||||
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata],
|
||||
cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, cache).view(buf.st)
|
||||
cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, metadata, cache).view(buf.st)
|
||||
return ret
|
||||
if buf.op is MetaOps.CONST: return buf_uops[buf.buffer]
|
||||
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
|
||||
if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
|
||||
if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
|
||||
return UOp.load(ubuf, buf.st.to_uop(), dtype=dtype)
|
||||
src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs)
|
||||
src = tuple(to_uop(x, outputs, inputs, buf_uops, metadata, cache) for x in buf.srcs)
|
||||
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
|
||||
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
|
||||
|
|
@ -174,14 +175,16 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
|
|||
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
|
||||
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
|
||||
cache[buf] = ret
|
||||
if buf.metadata is not None: metadata[ret] = buf.metadata
|
||||
return ret
|
||||
|
||||
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_vals:Dict[Variable, int]) -> LBScheduleItem:
|
||||
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
inputs: List[LazyBuffer] = []
|
||||
metadata: Dict[UOp, Metadata] = {}
|
||||
sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
|
||||
to_uop(out, outs, inputs, buf_uops, cache)) for out in outs))
|
||||
to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
|
||||
sink = full_ast_rewrite(sink, tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
|
||||
|
|
@ -189,8 +192,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
|
|||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
return LBScheduleItem(sink, tuple(outs+inputs),
|
||||
tuple(dedup([x.metadata for x in cache if x.metadata is not None and (x.base in outs or x.base.buffer not in buf_uops)])))
|
||||
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup(metadata.values())))
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue