flip offset and shape in pad and shrink (#16414)

* flip offset and shape in pad and shrink

* dumb test
This commit is contained in:
George Hotz 2026-05-28 11:58:19 -07:00 committed by GitHub
commit edca5df25a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 39 additions and 26 deletions

View file

@ -276,9 +276,9 @@ def apply_grad(grad_buf:Tensor, new_grad:UOp):
grad_buf.uop = grad_buf.uop.after(grad_buf.uop.store(grad_buf.uop + new_grad))
return
cur = grad_buf.uop
for pad in sorted(pads, key=lambda p: p.marg[0][1] if p.op == Ops.PAD else 0, reverse=True):
for pad in sorted(pads, key=lambda p: p.marg[0][0] if p.op == Ops.PAD else 0, reverse=True):
if pad.op == Ops.PAD:
grad_shrink = tuple([(p[1], s+p[1]) for s,p in zip(pad.src[0].shape, pad.marg)])
grad_shrink = tuple([(p[0], s+p[0]) for s,p in zip(pad.src[0].shape, pad.marg)])
buf_slice = cur.shrink(grad_shrink)
cur = cur.after(buf_slice.store(buf_slice + pad.src[0].cast(cur.dtype)))
else:

View file

@ -23,7 +23,7 @@ def calculate_storage_offset(x: Tensor) -> int:
for u in x.uop.toposort():
if u.op == Ops.SHRINK:
u_strides = strides_for_shape(u.src[0].shape)
for i, (_, start) in enumerate(u.marg): offset += start * u_strides[i]
for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
return offset
def wrap(x: Tensor) -> torch.Tensor:
x._strides = strides_for_shape(x.shape) # always recalculate

Binary file not shown.

View file

@ -74,8 +74,8 @@ A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \textt
\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\
\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\
\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\
\op{Pad} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Place $T$ at offset $o_k$ in a zero-filled output of shape $s'_k$. \\
\op{Shrink} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
\op{Pad} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Place $T$ at offset $o_k$ in a zero-filled output of shape $s'_k$. \\
\op{Shrink} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\

View file

@ -248,6 +248,19 @@ class TestViz(unittest.TestCase):
self.assertIn("EXPAND", excluded_nodes)
self.assertIn("CONST1 1 Ops.DEVICE", graph[id(alu)]["label"])
def test_stack_movement_not_folded_unless_all_const(self):
a = UOp.variable("a", 0, 10, dtype=dtypes.int)
c = UOp.const(dtypes.int, 1)
stack = a.vectorize(c)
reshaped = stack.reshape((1, 2))
graph = uop_to_json(VizData(), reshaped)
self.assertFalse(graph[id(stack)]["exclude"])
const_stack = c.vectorize(UOp.const(dtypes.int, 2))
const_reshaped = const_stack.reshape((1, 2))
const_graph = uop_to_json(VizData(), const_reshaped)
self.assertTrue(const_graph[id(const_stack)]["exclude"])
# VIZ displays nested graph_rewrites in a tree view
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None

View file

@ -72,8 +72,8 @@ pm_gradient = PatternMatcher([
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype))._rop(Ops.ADD, tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n))
.cast(ctx.dtype), None)),
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[1], s+p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[1], s-p[0]-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[0]-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),

View file

@ -178,7 +178,7 @@ class MovementMixin:
def pad(self, arg:tuple[tuple[sint, sint] | None, ...]) -> Self:
if self.ndim != len(arg):
raise ValueError(f"{self.ndim=} != {len(arg)=}")
ret = self._mop(Ops.PAD, tuple((s+x[0]+x[1], x[0]) if x is not None else (s, 0) for x, s in zip(arg, self.shape)))
ret = self._mop(Ops.PAD, tuple((x[0], s+x[0]+x[1]) if x is not None else (0, s) for x, s in zip(arg, self.shape)))
return self if ret.shape == self.shape else ret
def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
@ -200,7 +200,7 @@ class MovementMixin:
"""
if self.ndim != len(arg):
raise ValueError(f"{self.ndim=} != {len(arg)=}")
ret = self._mop(Ops.SHRINK, arg=[(x[1]-x[0], x[0]) if x is not None else (s, 0) for x, s in zip(arg, self.shape)])
ret = self._mop(Ops.SHRINK, arg=[(x[0], x[1]-x[0]) if x is not None else (0, s) for x, s in zip(arg, self.shape)])
return self if ret.shape == self.shape else ret
def permute(self, order, *args) -> Self:
@ -251,7 +251,7 @@ class MovementMixin:
return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))
def pad_to(self, shape, *args) -> Self:
return self._mop(Ops.PAD, tuple((s if ns is None else ns, 0) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)))
return self._mop(Ops.PAD, tuple((0, s if ns is None else ns) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)))
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""

View file

@ -128,7 +128,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U
@functools.cache
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
match op:
case Ops.SHRINK: rngs = tuple(a if off == 0 else a+off for a,(_,off) in zip(rngs, arg))
case Ops.SHRINK: rngs = tuple(a if off == 0 else a+off for a,(off,_) in zip(rngs, arg))
case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
case Ops.FLIP: rngs = tuple(((s-1)-a) if f else a for a,s,f in zip(rngs, in_shape, arg))
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
@ -136,7 +136,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
# NOTE: the .where(r-s, i) is not inside the graph_rewrite so that `convert_pad_to_where_to_keep_behavior_local`
# wraps the pad with only the newly added valid
rngs = tuple(r if (sz == sh and off == 0) else graph_rewrite((r >= off) & (r < (sh+off)),
symbolic+pm_simplify_valid, name="pad").where(r-off, UOp.invalid()) for r,sh,(sz,off) in zip(rngs, in_shape, arg))
symbolic+pm_simplify_valid, name="pad").where(r-off, UOp.invalid()) for r,sh,(off,sz) in zip(rngs, in_shape, arg))
case Ops.RESHAPE:
sink = UOp.sink(*rngs).simplify() # NOTE: this applies any commutative flips to the rngs early
sub_array = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}

View file

@ -88,8 +88,8 @@ def expand_multi(root:UOp, multi:UOp):
return multi.src[0].expand(new_shape).multi(multi.axis)
def pad_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.marg[multi.axis] == (multi.shape[multi.axis], 0), f"padding not supported for {root.marg=}"
local_pad = tuple((multi.src[0].shape[multi.axis], 0) if a == multi.axis else s for a,s in enumerate(root.marg))
assert multi.axis is None or root.marg[multi.axis] == (0, multi.shape[multi.axis]), f"padding not supported for {root.marg=}"
local_pad = tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.marg))
return multi.src[0]._mop(Ops.PAD, local_pad).multi(multi.axis)
def permute_multi(root:UOp, multi:UOp):
@ -97,15 +97,15 @@ def permute_multi(root:UOp, multi:UOp):
return multi.src[0].permute(root.marg).multi(root.axis)
def shrink_multi(root:UOp, multi:UOp):
shard_bounds = tuple((e-s,s) for s,e in multi.bounds) if multi.axis is not None else ()
assert multi.axis is None or root.marg[multi.axis] == (multi.shape[multi.axis], 0) or root.marg[multi.axis] in shard_bounds, \
shard_bounds = tuple((s,e-s) for s,e in multi.bounds) if multi.axis is not None else ()
assert multi.axis is None or root.marg[multi.axis] == (0, multi.shape[multi.axis]) or root.marg[multi.axis] in shard_bounds, \
f"shrinking not supported for {root.marg=}"
if multi.axis is not None and root.marg[multi.axis] in shard_bounds and root.marg[multi.axis] != (multi.shape[multi.axis], 0):
if multi.axis is not None and root.marg[multi.axis] in shard_bounds and root.marg[multi.axis] != (0, multi.shape[multi.axis]):
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
# we just copy it to all the devices, no real. this will be optimized out later
non_shard_shrink = tuple((multi.src[0].shape[i], 0) if i == multi.axis else s for i, s in enumerate(root.marg))
non_shard_shrink = tuple((0, multi.src[0].shape[i]) if i == multi.axis else s for i, s in enumerate(root.marg))
return multi.src[0].copy_to_device(multi.device, arg=shard_bounds.index(root.marg[multi.axis]))._mop(Ops.SHRINK, non_shard_shrink)
local_shrink = tuple((multi.src[0].shape[multi.axis], 0) if a == multi.axis else s for a,s in enumerate(root.marg))
local_shrink = tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.marg))
return multi.src[0]._mop(Ops.SHRINK, local_shrink).multi(multi.axis)
def flip_multi(root:UOp, multi:UOp):

View file

@ -52,7 +52,7 @@ def found_after(ctx:dict[UOp, UOp], after:UOp, src:UOp):
if x.op is Ops.PERMUTE: x, after = x.src[0], after.permute(argsort(x.marg))
elif x.op is Ops.RESHAPE: x, after = x.src[0], after.reshape(x.src[0].shape)
elif x.op is Ops.WHERE and x.src[2].base.arg == Invalid and x.src[1].op is Ops.PAD:
x, after = x.src[1].src[0], after.shrink(tuple((o, s+o) for (_,o),s in zip(x.src[1].marg, x.src[1].src[0].shape)))
x, after = x.src[1].src[0], after.shrink(tuple((o, s+o) for (o,_),s in zip(x.src[1].marg, x.src[1].src[0].shape)))
else: break
ctx[x] = after

View file

@ -327,14 +327,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return tuple(ps[i] for i in self.marg)
case Ops.PAD:
# TODO: why do i need resolve here?
if len(ps) != len(self.marg) or not all(resolve(sz>=0) and resolve(0<=o) and resolve(o+s<=sz) for s,(sz,o) in zip(ps, self.marg)):
if len(ps) != len(self.marg) or not all(resolve(sz>=0) and resolve(0<=o) and resolve(o+s<=sz) for s,(o,sz) in zip(ps, self.marg)):
raise ValueError(f"invalid pad {self.marg} for {ps}")
return tuple(ssimplify(sz) for sz,_ in self.marg)
return tuple(ssimplify(sz) for _,sz in self.marg)
case Ops.SHRINK:
# TODO: why do i need resolve here?
if len(ps) != len(self.marg) or not all(resolve(0<=b) and resolve(sz>=0) and resolve(b+sz<=s) for s,(sz,b) in zip(ps, self.marg)):
if len(ps) != len(self.marg) or not all(resolve(0<=o) and resolve(sz>=0) and resolve(o+sz<=s) for s,(o,sz) in zip(ps, self.marg)):
raise ValueError(f"invalid shrink {self.marg} for {ps}")
return tuple(ssimplify(sz) for sz,_ in self.marg)
return tuple(ssimplify(sz) for _,sz in self.marg)
case Ops.FLIP:
if len(ps) != len(self.marg) or not all(isinstance(x, bool) for x in self.marg): raise ValueError(f"bad flip on {ps}, {self.marg}")
return ps
@ -619,7 +619,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
if len(self.src) == 0: return None
src_axis = self.src[0].axis
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (self.src[0].shape[src_axis], 0):
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (0, self.src[0].shape[src_axis]):
return None # SHRINK will remove the sharding if it's on axis
if self.op is Ops.REDUCE: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
if self.op is Ops.RESHAPE:

View file

@ -121,7 +121,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
if u.op in GroupOp.Movement: excluded.update(s for s in u.src if s.op is Ops.STACK)
if u.op in GroupOp.Movement: excluded.update(s for s in u.src if s.op is Ops.STACK and all(x.op is Ops.CONST for x in s.src))
for u in toposort:
argst = codecs.decode(str(u.arg), "unicode_escape")
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)