mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cleanup view on reduce [pr] (#7081)
This commit is contained in:
parent
067b35e915
commit
207fbc4bc7
1 changed files with 7 additions and 8 deletions
|
|
@ -62,21 +62,20 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
|
|||
|
||||
# ** reduceop fusor
|
||||
|
||||
def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
|
||||
if (swizzle_st:=unwrap(swizzle.st)).contiguous: return None
|
||||
rsrc = reduceop.src[0]
|
||||
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), reduceop.axis_arg)
|
||||
def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]:
|
||||
if (st:=unwrap(view.st)).contiguous: return None
|
||||
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), r.axis_arg)
|
||||
prshape = prod(rshape)
|
||||
strides = strides_for_shape(rshape)
|
||||
nv: List[View] = []
|
||||
for v in swizzle_st.views:
|
||||
for v in st.views:
|
||||
nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
|
||||
# update input_st and axis
|
||||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||
_, new_rshape = permute_reduce(new_input_st, reduceop.axis_arg)
|
||||
_, new_rshape = permute_reduce(new_input_st, r.axis_arg)
|
||||
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
||||
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(reduceop.arg[0], new_axis).view(ShapeTracker.from_shape(swizzle_st.shape))
|
||||
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
||||
|
||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
|
||||
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
|
||||
|
|
@ -107,7 +106,7 @@ merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),)
|
|||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view on reduce
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
# view on elementwise
|
||||
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue