mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix rangeify indexing for pad/reduce (#16599)
This commit is contained in:
parent
96b86aad7b
commit
aa32d309db
2 changed files with 12 additions and 6 deletions
|
|
@ -61,6 +61,14 @@ class TestSchedule(unittest.TestCase):
|
|||
# NOTE: the gradient flows twice
|
||||
np.testing.assert_allclose(out.numpy(), 2*np.ones((64,64)))
|
||||
|
||||
def test_pad_reduce_scope_collision(self):
|
||||
b = Tensor.rand(4, 3).realize()
|
||||
s1 = b.pad(((1, 1), (0, 0))).sum(axis=1)
|
||||
s2 = b.pad(((1, 2), (0, 0))).shrink(((0, 6), (0, 3))).sum(axis=1)
|
||||
out = s1 + s2
|
||||
run_linear(*check_schedule(out, 1))
|
||||
np.testing.assert_allclose(out.numpy(), 2*np.pad(b.numpy(), ((1, 1), (0, 0))).sum(axis=1), rtol=1e-6)
|
||||
|
||||
def test_cumsum_parallel_reduce_fused(self):
|
||||
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END (same nesting context = should merge)
|
||||
step, num_steps = 513, 10
|
||||
|
|
|
|||
|
|
@ -82,18 +82,16 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
|
||||
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
|
||||
if x not in ctx.range_map: return None
|
||||
if (bx:=create_bufferize_and_index_based_on_ranges(ctx, x)) is None: bx = x
|
||||
valid: UOp = UOp.const(dtypes.bool, True).uprod([r.get_valid() for r in ctx.range_map[x][0]])
|
||||
ret = valid.where(x.src[0], UOp.const(x.dtype, 0))
|
||||
ctx.range_map[ret] = ctx.range_map[x]
|
||||
return ret
|
||||
return valid.where(bx.src[0], UOp.const(x.dtype, 0))
|
||||
|
||||
def convert_reduce_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
||||
if len(x.arg[1]) == 0: return None
|
||||
if (bx:=create_bufferize_and_index_based_on_ranges(ctx, x)) is None: bx = x
|
||||
# input ranges
|
||||
new_ranges = [r for i,r in enumerate(ctx.range_map[x][0]) if i in x.arg[1]]
|
||||
ret = UOp(Ops.REDUCE, x.dtype, src=(x.src[0],)+tuple(new_ranges), arg=(x.arg[0], ()))
|
||||
ctx.range_map[ret] = ctx.range_map[x]
|
||||
return ret
|
||||
return UOp(Ops.REDUCE, x.dtype, src=(bx.src[0],)+tuple(new_ranges), arg=(x.arg[0], ()))
|
||||
|
||||
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue