fix rangeify indexing for pad/reduce (#16599)

This commit is contained in:
chenyu 2026-06-12 20:26:15 -04:00 committed by GitHub
commit aa32d309db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 12 additions and 6 deletions

View file

@@ -61,6 +61,14 @@ class TestSchedule(unittest.TestCase):
# NOTE: the gradient flows twice
np.testing.assert_allclose(out.numpy(), 2*np.ones((64,64)))
def test_pad_reduce_scope_collision(self):
  # Two axis-1 sums over differently padded views of one realized buffer are added together.
  base = Tensor.rand(4, 3).realize()
  no_col_pad = (0, 0)
  sum_a = base.pad(((1, 1), no_col_pad)).sum(axis=1)
  # pad to 7 rows, then shrink back to the first 6 so both operands have the same length
  sum_b = base.pad(((1, 2), no_col_pad)).shrink(((0, 6), (0, 3))).sum(axis=1)
  out = sum_a + sum_b
  # NOTE(review): the `1` is presumably the expected kernel count -- confirm against check_schedule
  run_linear(*check_schedule(out, 1))
  actual = out.numpy()
  # both operands equal the row sums of the (1, 1)-padded buffer, hence the factor of 2
  expected = 2*np.pad(base.numpy(), ((1, 1), (0, 0))).sum(axis=1)
  np.testing.assert_allclose(actual, expected, rtol=1e-6)
def test_cumsum_parallel_reduce_fused(self):
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
step, num_steps = 513, 10

View file

@@ -82,18 +82,16 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
  """Rewrite `x` as `valid.where(src, 0)`.

  `valid` is the product (`uprod`) of `get_valid()` over the ranges stored in `ctx.range_map[x][0]`,
  starting from a constant True. The result selects the source where `valid` holds and a zero constant
  of `x.dtype` otherwise.

  Returns None when `x` has no entry in `ctx.range_map`.
  """
  if x not in ctx.range_map: return None
  # take the source from the bufferized/indexed version of x when the helper produces one,
  # otherwise fall back to x's own source
  if (bx:=create_bufferize_and_index_based_on_ranges(ctx, x)) is None: bx = x
  valid: UOp = UOp.const(dtypes.bool, True).uprod([r.get_valid() for r in ctx.range_map[x][0]])
  # single exit: the source must come from bx, not x, so the helper's result is actually used
  return valid.where(bx.src[0], UOp.const(x.dtype, 0))
def convert_reduce_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
  """Rebuild `x` as an `Ops.REDUCE` whose sources are the input followed by the ranges being reduced.

  `x.arg[1]` holds the positions of the ranges to reduce over; the matching entries of
  `ctx.range_map[x][0]` are appended to the new node's sources. The new node's arg is
  `(x.arg[0], ())`, i.e. the first arg element is kept and the positions are emptied.

  Returns None when `x.arg[1]` is empty.
  """
  if len(x.arg[1]) == 0: return None
  # take the source from the bufferized/indexed version of x when the helper produces one,
  # otherwise fall back to x's own source
  if (bx:=create_bufferize_and_index_based_on_ranges(ctx, x)) is None: bx = x
  # input ranges
  new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]]
  # single exit: the source must come from bx, not x, so the helper's result is actually used
  return UOp(Ops.REDUCE, x.dtype, src=(bx.src[0],)+tuple(new_ranges), arg=(x.arg[0], ()))
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]