try 2 on VIEW(BUFFER, <op>) scheduling + spec [pr] (#8377)

* second iteration on VIEW(BUFFER, <op>) scheduling + spec [pr]

* image

* notes
This commit is contained in:
qazal 2024-12-22 16:30:35 +02:00 committed by GitHub
commit e6f4c24619
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -546,30 +546,35 @@ do_realize = PatternMatcher([
(UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"), lambda dest,src,x: x.replace(src=(dest.base.buf_uop, src))),
])
# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
def generate_const(x:UOp, st:UOp):
# NOTE: masked VIEW stacks on top of the CONST, this is required for const folding correctness
assert all(v.mask is None for v in unwrap(st.st).views), f"ShapeTracker of CONST must be unmasked, got {st}"
return UOp(Ops.VALID, dtypes.bool, (unwrap(st.st).to_uop(),)).where(x.replace(dtype=x.dtype.base), 0)
def unbind_variable(ctx:ScheduleContext, bind:UOp, st:UOp):
ctx.var_vals.update([bind.unbind()])
return UOp.const_with_shape(bind.dtype, bind, st.shape)
return generate_const(UOp.const(bind.dtype, bind), st)
def append_realize(ctx:ScheduleContext, b:UOp, to_store:UOp, base:UOp) -> UOp:
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(base.shape).to_uop(), append_op(ctx, b, to_store))
return UOp(Ops.LOAD, base.dtype, (b, unwrap(base.st).to_uop()))
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
assert st.size == b.size and unwrap(st.st).contiguous, f"ShapeTracker of realized {b} BUFFER must match the BUFFER size {st}"
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
def append_op(ctx:ScheduleContext, b:UOp, to_store:UOp) -> UOp:
# TODO: metadata post merge
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[to_store] = m
return to_store
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
if (m:=ctx.tensor_uops[b][0].metadata) is not None: ctx.ops_metadata[x] = m
if b not in ctx.realizes: return x # collapse BUFFER
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
break_sched = PatternMatcher([
# consts are always fused and generated
(UPat(Ops.VIEW, name="root", src=(UPat(), UPat.cvar())), lambda root: UOp.const_with_shape(root.dtype.base, root.const_arg, root.shape)),
# values from BIND append to this schedule's var_vals
(UPat(Ops.VIEW, name="st", src=(UPat(), UPat(Ops.BIND, name="bind"))), unbind_variable),
# view of realized buffer just loads
(UPat(Ops.BUFFER, name="b").view(name="v"), lambda ctx,b,v: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, v.st.to_uop()))),
# all other views either fold or realize with a store
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
# CONST is always fused and generated
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE), UPat(Ops.CONST, name="x"))), generate_const),
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE), UPat(Ops.BIND, name="bind"))), unbind_variable),
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** Schedule context builder