cleanup view on reduce [pr] (#7081)

This commit is contained in:
qazal 2024-10-16 05:22:52 +03:00 committed by GitHub
commit 207fbc4bc7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -62,21 +62,20 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
# ** reduceop fusor
def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
if (swizzle_st:=unwrap(swizzle.st)).contiguous: return None
rsrc = reduceop.src[0]
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.axis_arg)
def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]:
if (st:=unwrap(view.st)).contiguous: return None
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), r.axis_arg)
prshape = prod(rshape)
strides = strides_for_shape(rshape)
nv: List[View] = []
for v in swizzle_st.views:
for v in st.views:
nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
# update input_st and axis
new_input_st = tmp + ShapeTracker(tuple(nv))
_, new_rshape = permute_reduce(new_input_st, reduceop.axis_arg)
_, new_rshape = permute_reduce(new_input_st, r.axis_arg)
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(reduceop.arg[0], new_axis).view(ShapeTracker.from_shape(swizzle_st.shape))
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
@ -107,7 +106,7 @@ merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),)
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view on reduce
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# view on elementwise
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),