mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
try 2 on VIEW(BUFFER, <op>) scheduling + spec [pr] (#8377)
* second iteration on VIEW(BUFFER, <op>) scheduling + spec [pr] * image * notes
This commit is contained in:
parent
b7397c1322
commit
e6f4c24619
1 changed files with 22 additions and 17 deletions
|
|
@ -546,30 +546,35 @@ do_realize = PatternMatcher([
|
|||
(UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"), lambda dest,src,x: x.replace(src=(dest.base.buf_uop, src))),
|
||||
])
|
||||
|
||||
# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
|
||||
# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
|
||||
|
||||
def generate_const(x:UOp, st:UOp):
|
||||
# NOTE: masked VIEW stacks on top of the CONST, this is required for const folding correctness
|
||||
assert all(v.mask is None for v in unwrap(st.st).views), f"ShapeTracker of CONST must be unmasked, got {st}"
|
||||
return UOp(Ops.VALID, dtypes.bool, (unwrap(st.st).to_uop(),)).where(x.replace(dtype=x.dtype.base), 0)
|
||||
|
||||
def unbind_variable(ctx:ScheduleContext, bind:UOp, st:UOp):
|
||||
ctx.var_vals.update([bind.unbind()])
|
||||
return UOp.const_with_shape(bind.dtype, bind, st.shape)
|
||||
return generate_const(UOp.const(bind.dtype, bind), st)
|
||||
|
||||
def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
|
||||
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(base.shape).to_uop(), append_op(ctx, b, to_store))
|
||||
return UOp(Ops.LOAD, base.dtype, (b, unwrap(base.st).to_uop()))
|
||||
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
|
||||
assert st.size == b.size and unwrap(st.st).contiguous, f"ShapeTracker of realized {b} BUFFER must match the BUFFER size {st}"
|
||||
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
|
||||
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
|
||||
# TODO: metadata post merge
|
||||
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
|
||||
return to_store
|
||||
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
|
||||
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m
|
||||
if b not in ctx.realizes: return x # collapse BUFFER
|
||||
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
|
||||
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
break_sched = PatternMatcher([
|
||||
# consts are always fused and generated
|
||||
(UPat(Ops.VIEW, name="root", src=(UPat(), UPat.cvar())), lambda root: UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)),
|
||||
# values from BIND append to this schedule's var_vals
|
||||
(UPat(Ops.VIEW, name="st", src=(UPat(), UPat(Ops.BIND, name="bind"))), unbind_variable),
|
||||
# view of realized buffer just loads
|
||||
(UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, v.st.to_uop()))),
|
||||
# all other views either fold or realize with a store
|
||||
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
|
||||
# CONST is always fused and generated
|
||||
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE), UPat(Ops.CONST, name="x"))), generate_const),
|
||||
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE), UPat(Ops.BIND, name="bind"))), unbind_variable),
|
||||
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
|
||||
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
|
||||
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
|
||||
])
|
||||
|
||||
# **** Schedule context builder
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue