shrink/pad use (new_shape, offset) (#16405)

* shrink uses offset and shape

* pad does too

* fix
This commit is contained in:
George Hotz 2026-05-27 15:13:08 -07:00 committed by GitHub
commit 8ee3a37524
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 35 additions and 29 deletions

View file

@ -276,9 +276,9 @@ def apply_grad(grad_buf:Tensor, new_grad:UOp):
grad_buf.uop = grad_buf.uop.after(grad_buf.uop.store(grad_buf.uop + new_grad))
return
cur = grad_buf.uop
for pad in sorted(pads, key=lambda p: p.marg[0][0] if p.op == Ops.PAD else 0, reverse=True):
for pad in sorted(pads, key=lambda p: p.marg[0][1] if p.op == Ops.PAD else 0, reverse=True):
if pad.op == Ops.PAD:
grad_shrink = tuple([(p[0], s+p[0]) for s,p in zip(pad.src[0].shape, pad.marg)])
grad_shrink = tuple([(p[1], s+p[1]) for s,p in zip(pad.src[0].shape, pad.marg)])
buf_slice = cur.shrink(grad_shrink)
cur = cur.after(buf_slice.store(buf_slice + pad.src[0].cast(cur.dtype)))
else:

View file

@ -23,7 +23,7 @@ def calculate_storage_offset(x: Tensor) -> int:
for u in x.uop.toposort():
if u.op == Ops.SHRINK:
u_strides = strides_for_shape(u.src[0].shape)
for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
for i, (_, start) in enumerate(u.marg): offset += start * u_strides[i]
return offset
def wrap(x: Tensor) -> torch.Tensor:
x._strides = strides_for_shape(x.shape) # always recalculate

Binary file not shown.

View file

@ -74,8 +74,8 @@ A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \textt
\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\
\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\
\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\
\op{Pad} & $(T, \mathbf{b}, \mathbf{e})$ & --- & Pad with $0$s: $b_k$ before, $e_k$ after each axis. \\
\op{Shrink} & $(T, \mathbf{b}, \mathbf{e})$ & --- & Keep $[b_k, e_k)$ per axis. Inverse of \op{Pad}. \\
\op{Pad} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Place $T$ at offset $o_k$ in a zero-filled output of shape $s'_k$. \\
\op{Shrink} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\

View file

@ -72,8 +72,8 @@ pm_gradient = PatternMatcher([
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype))._rop(Ops.ADD, tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n))
.cast(ctx.dtype), None)),
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[1], s+p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[1], s-p[0]-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),

View file

@ -178,7 +178,7 @@ class MovementMixin:
def pad(self, arg:tuple[tuple[sint, sint] | None, ...]) -> Self:
if self.ndim != len(arg):
raise ValueError(f"{self.ndim=} != {len(arg)=}")
ret = self._mop(Ops.PAD, tuple(x if x is not None else (0, 0) for x in arg))
ret = self._mop(Ops.PAD, tuple((s+x[0]+x[1], x[0]) if x is not None else (s, 0) for x, s in zip(arg, self.shape)))
return self if ret.shape == self.shape else ret
def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
@ -200,7 +200,7 @@ class MovementMixin:
"""
if self.ndim != len(arg):
raise ValueError(f"{self.ndim=} != {len(arg)=}")
ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
ret = self._mop(Ops.SHRINK, arg=[(x[1]-x[0], x[0]) if x is not None else (s, 0) for x, s in zip(arg, self.shape)])
return self if ret.shape == self.shape else ret
def permute(self, order, *args) -> Self:
@ -251,7 +251,7 @@ class MovementMixin:
return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))
def pad_to(self, shape, *args) -> Self:
return self._mop(Ops.PAD, tuple([(0, 0 if ns is None else ns-s) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)]))
return self._mop(Ops.PAD, tuple((s if ns is None else ns, 0) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)))
def view(self, shape, *args) -> Self:
"""`.view` is an alias for `.reshape`."""

View file

@ -128,15 +128,15 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U
@functools.cache
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
match op:
case Ops.SHRINK: rngs = tuple(a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, arg))
case Ops.SHRINK: rngs = tuple(a if off == 0 else a+off for a,(_,off) in zip(rngs, arg))
case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
case Ops.FLIP: rngs = tuple(((s-1)-a) if f else a for a,s,f in zip(rngs, in_shape, arg))
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
case Ops.PAD:
# NOTE: the .where(r-s, i) is not inside the graph_rewrite so that `convert_pad_to_where_to_keep_behavior_local`
# wraps the pad with only the newly added valid
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite((r >= s) & (r < (sh+s)),
symbolic+pm_simplify_valid, name="pad").where(r-s, UOp.invalid()) for r,sh,(s,e) in zip(rngs, in_shape, arg))
rngs = tuple(r if (sz == sh and off == 0) else graph_rewrite((r >= off) & (r < (sh+off)),
symbolic+pm_simplify_valid, name="pad").where(r-off, UOp.invalid()) for r,sh,(sz,off) in zip(rngs, in_shape, arg))
case Ops.RESHAPE:
sink = UOp.sink(*rngs).simplify() # NOTE: this applies any commutative flips to the rngs early
sub_array = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}

View file

@ -10,7 +10,7 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
def apply_shrink(s:UOp, i:int) -> UOp:
new_arg = [tuple([x.substitute({dvar[0]:dvar[0].const_like(i)}) if isinstance(x, UOp) and
(dvar:=[v for v in x.variables() if v.expr=='_device_num']) else x for x in ss]) for ss in shrink.marg]
return s.shrink(tuple(new_arg))
return s._mop(Ops.SHRINK, tuple(new_arg))
for i, x in enumerate(ms.src):
if x.op is Ops.COPY:
ret.append(apply_shrink(x.src[0], i).copy_to_device(x.device))
@ -88,22 +88,25 @@ def expand_multi(root:UOp, multi:UOp):
return multi.src[0].expand(new_shape).multi(multi.axis)
def pad_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.marg[multi.axis] == (0,0), f"padding not supported for {root.marg=}"
return multi.src[0].pad(root.marg).multi(multi.axis)
assert multi.axis is None or root.marg[multi.axis] == (multi.shape[multi.axis], 0), f"padding not supported for {root.marg=}"
local_pad = tuple((multi.src[0].shape[multi.axis], 0) if a == multi.axis else s for a,s in enumerate(root.marg))
return multi.src[0]._mop(Ops.PAD, local_pad).multi(multi.axis)
def permute_multi(root:UOp, multi:UOp):
# all permutes supported!
return multi.src[0].permute(root.marg).multi(root.axis)
def shrink_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.marg[multi.axis] == (0, multi.shape[multi.axis]) or root.marg[multi.axis] in multi.bounds, \
shard_bounds = tuple((e-s,s) for s,e in multi.bounds) if multi.axis is not None else ()
assert multi.axis is None or root.marg[multi.axis] == (multi.shape[multi.axis], 0) or root.marg[multi.axis] in shard_bounds, \
f"shrinking not supported for {root.marg=}"
if multi.axis is not None and root.marg[multi.axis] in multi.bounds and root.marg[multi.axis] != (0, multi.shape[multi.axis]):
if multi.axis is not None and root.marg[multi.axis] in shard_bounds and root.marg[multi.axis] != (multi.shape[multi.axis], 0):
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
# we just copy it to all the devices, no real. this will be optimized out later
non_shard_shrink = tuple((0, multi.src[0].shape[i]) if i == multi.axis else s for i, s in enumerate(root.marg))
return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.marg[multi.axis])).shrink(non_shard_shrink)
return multi.src[0].shrink(tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.marg))).multi(multi.axis)
non_shard_shrink = tuple((multi.src[0].shape[i], 0) if i == multi.axis else s for i, s in enumerate(root.marg))
return multi.src[0].copy_to_device(multi.device, arg=shard_bounds.index(root.marg[multi.axis]))._mop(Ops.SHRINK, non_shard_shrink)
local_shrink = tuple((multi.src[0].shape[multi.axis], 0) if a == multi.axis else s for a,s in enumerate(root.marg))
return multi.src[0]._mop(Ops.SHRINK, local_shrink).multi(multi.axis)
def flip_multi(root:UOp, multi:UOp):
assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis"

View file

@ -52,7 +52,7 @@ def found_after(ctx:dict[UOp, UOp], after:UOp, src:UOp):
if x.op is Ops.PERMUTE: x, after = x.src[0], after.permute(argsort(x.marg))
elif x.op is Ops.RESHAPE: x, after = x.src[0], after.reshape(x.src[0].shape)
elif x.op is Ops.WHERE and x.src[2].base.arg == Invalid and x.src[1].op is Ops.PAD:
x, after = x.src[1].src[0], after.shrink(tuple((l, s-r) for (l,r),s in zip(x.src[1].marg, x.shape)))
x, after = x.src[1].src[0], after.shrink(tuple((o, s+o) for (_,o),s in zip(x.src[1].marg, x.src[1].src[0].shape)))
else: break
ctx[x] = after

View file

@ -327,13 +327,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return tuple(ps[i] for i in self.marg)
case Ops.PAD:
# TODO: why do i need resolve here?
if len(ps) != len(self.marg) or not all(resolve(b>=0) and resolve(e>=0) for b,e in self.marg): raise ValueError(f"invalid pad {self.marg}")
return tuple(ssimplify(s+b+e) for s,(b,e) in zip(ps, self.marg))
if len(ps) != len(self.marg) or not all(resolve(sz>=0) and resolve(0<=o) and resolve(o+s<=sz) for s,(sz,o) in zip(ps, self.marg)):
raise ValueError(f"invalid pad {self.marg} for {ps}")
return tuple(ssimplify(sz) for sz,_ in self.marg)
case Ops.SHRINK:
# TODO: why do i need resolve here?
if len(ps) != len(self.marg) or not all(resolve(0<=b) and resolve(b<=e) and resolve(e<=s) for s,(b,e) in zip(ps, self.marg)):
if len(ps) != len(self.marg) or not all(resolve(0<=b) and resolve(sz>=0) and resolve(b+sz<=s) for s,(sz,b) in zip(ps, self.marg)):
raise ValueError(f"invalid shrink {self.marg} for {ps}")
return tuple(ssimplify(e-s) for s,e in self.marg)
return tuple(ssimplify(sz) for sz,_ in self.marg)
case Ops.FLIP:
if len(ps) != len(self.marg) or not all(isinstance(x, bool) for x in self.marg): raise ValueError(f"bad flip on {ps}, {self.marg}")
return ps
@ -618,7 +619,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
if len(self.src) == 0: return None
src_axis = self.src[0].axis
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (0, self.src[0].shape[src_axis]):
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (self.src[0].shape[src_axis], 0):
return None # SHRINK will remove the sharding if it's on axis
if self.op is Ops.REDUCE: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
if self.op is Ops.RESHAPE:
@ -679,7 +680,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def marg(self):
match self.op:
case Ops.RESHAPE | Ops.EXPAND: return tuple(ssimplify(self.src[1].sgep(i)) for i in range(self.src[1].dtype.count))
case Ops.PAD | Ops.SHRINK: return tuple((self.src[1].sgep(i), self.src[2].sgep(i)) for i in range(self.src[1].dtype.count))
case Ops.PAD: return tuple((self.src[1].sgep(i), self.src[2].sgep(i)) for i in range(self.src[1].dtype.count))
case Ops.SHRINK: return tuple((self.src[1].sgep(i), self.src[2].sgep(i)) for i in range(self.src[1].dtype.count))
case Ops.PERMUTE | Ops.FLIP: return self.arg
case _: raise RuntimeError(f"{self.op} is not a MovementOp")

View file

@ -158,7 +158,8 @@ spec_tensor = PatternMatcher([
# movement ops
(UPat((Ops.RESHAPE, Ops.EXPAND), src=(UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint))), lambda: True),
(UPat((Ops.PAD, Ops.SHRINK), src=(UPat(), UPat(dtype=dtypes.weakint), UPat(dtype=dtypes.weakint)), name="x"),
lambda x: x.src[1].dtype.count == x.src[2].dtype.count),
(UPat((Ops.PERMUTE, Ops.FLIP), name="mv", src=(UPat(),)), lambda mv: isinstance(mv.arg, tuple)),
# REDUCE has arg=(op, axis_tuple), src[1:] are ranges after lowering