mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
flip offset and shape in pad and shrink (#16414)
* flip offset and shape in pad and shrink
* dumb test
This commit is contained in:
parent
d72d8ee065
commit
edca5df25a
12 changed files with 39 additions and 26 deletions
|
|
@ -276,9 +276,9 @@ def apply_grad(grad_buf:Tensor, new_grad:UOp):
|
|||
grad_buf.uop = grad_buf.uop.after(grad_buf.uop.store(grad_buf.uop + new_grad))
|
||||
return
|
||||
cur = grad_buf.uop
|
||||
for pad in sorted(pads, key=lambda p: p.marg[0][1] if p.op == Ops.PAD else 0, reverse=True):
|
||||
for pad in sorted(pads, key=lambda p: p.marg[0][0] if p.op == Ops.PAD else 0, reverse=True):
|
||||
if pad.op == Ops.PAD:
|
||||
grad_shrink = tuple([(p[1], s+p[1]) for s,p in zip(pad.src[0].shape, pad.marg)])
|
||||
grad_shrink = tuple([(p[0], s+p[0]) for s,p in zip(pad.src[0].shape, pad.marg)])
|
||||
buf_slice = cur.shrink(grad_shrink)
|
||||
cur = cur.after(buf_slice.store(buf_slice + pad.src[0].cast(cur.dtype)))
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def calculate_storage_offset(x: Tensor) -> int:
|
|||
for u in x.uop.toposort():
|
||||
if u.op == Ops.SHRINK:
|
||||
u_strides = strides_for_shape(u.src[0].shape)
|
||||
for i, (_, start) in enumerate(u.marg): offset += start * u_strides[i]
|
||||
for i, (start, _) in enumerate(u.marg): offset += start * u_strides[i]
|
||||
return offset
|
||||
def wrap(x: Tensor) -> torch.Tensor:
|
||||
x._strides = strides_for_shape(x.shape) # always recalculate
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -74,8 +74,8 @@ A \op{Buffer}'s \textbf{addrspace} is \texttt{GLOBAL}, \texttt{LOCAL}, or \textt
|
|||
\op{Flip} & $(T,)$ & bools $\mathbf{f}$ & Reverse along flagged axes. \\
|
||||
\op{Reshape} & $(T, \mathbf{s'})$ & --- & Reinterpret in row-major order. $\prod s_k = \prod s'_k$. \\
|
||||
\op{Expand} & $(T, \mathbf{s'})$ & --- & Broadcast size-1 axes. $s_k \in \{1, s'_k\}$. \\
|
||||
\op{Pad} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Place $T$ at offset $o_k$ in a zero-filled output of shape $s'_k$. \\
|
||||
\op{Shrink} & $(T, \mathbf{s'}, \mathbf{o})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
|
||||
\op{Pad} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Place $T$ at offset $o_k$ in a zero-filled output of shape $s'_k$. \\
|
||||
\op{Shrink} & $(T, \mathbf{o}, \mathbf{s'})$ & --- & Keep $s'_k$ elements starting at offset $o_k$ per axis. Inverse of \op{Pad}. \\
|
||||
\op{Index} & $(T, i_0, i_1, \ldots)$ & --- & Index from left. $()$-shaped $i$ removes dim; $(k,)$-shaped makes it $k$. \\
|
||||
\op{Stack} & $(T_0, T_1, \ldots)$ & --- & Join along a newly created leading axis. All shapes must match. \\
|
||||
\op{Replicated} & $(T,)$ & axes & Mark $T$ as replicated along axes. Collapse axes to $1$. \\
|
||||
|
|
|
|||
|
|
@ -248,6 +248,19 @@ class TestViz(unittest.TestCase):
|
|||
self.assertIn("EXPAND", excluded_nodes)
|
||||
self.assertIn("CONST1 1 Ops.DEVICE", graph[id(alu)]["label"])
|
||||
|
||||
def test_stack_movement_not_folded_unless_all_const(self):
|
||||
a = UOp.variable("a", 0, 10, dtype=dtypes.int)
|
||||
c = UOp.const(dtypes.int, 1)
|
||||
stack = a.vectorize(c)
|
||||
reshaped = stack.reshape((1, 2))
|
||||
graph = uop_to_json(VizData(), reshaped)
|
||||
self.assertFalse(graph[id(stack)]["exclude"])
|
||||
|
||||
const_stack = c.vectorize(UOp.const(dtypes.int, 2))
|
||||
const_reshaped = const_stack.reshape((1, 2))
|
||||
const_graph = uop_to_json(VizData(), const_reshaped)
|
||||
self.assertTrue(const_graph[id(const_stack)]["exclude"])
|
||||
|
||||
# VIZ displays nested graph_rewrites in a tree view
|
||||
|
||||
def leaf_rewrite(x:UOp): return x.rtag(1) if x.tag is None else None
|
||||
|
|
|
|||
|
|
@ -72,8 +72,8 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
|
||||
(ctx.cast(sum_acc_dtype(ctx.dtype))._rop(Ops.ADD, tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n))
|
||||
.cast(ctx.dtype), None)),
|
||||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[1], s+p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[1], s-p[0]-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[0]-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
|
||||
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
|
||||
(UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
|
||||
|
|
|
|||
|
|
@ -178,7 +178,7 @@ class MovementMixin:
|
|||
def pad(self, arg:tuple[tuple[sint, sint] | None, ...]) -> Self:
|
||||
if self.ndim != len(arg):
|
||||
raise ValueError(f"{self.ndim=} != {len(arg)=}")
|
||||
ret = self._mop(Ops.PAD, tuple((s+x[0]+x[1], x[0]) if x is not None else (s, 0) for x, s in zip(arg, self.shape)))
|
||||
ret = self._mop(Ops.PAD, tuple((x[0], s+x[0]+x[1]) if x is not None else (0, s) for x, s in zip(arg, self.shape)))
|
||||
return self if ret.shape == self.shape else ret
|
||||
|
||||
def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
|
||||
|
|
@ -200,7 +200,7 @@ class MovementMixin:
|
|||
"""
|
||||
if self.ndim != len(arg):
|
||||
raise ValueError(f"{self.ndim=} != {len(arg)=}")
|
||||
ret = self._mop(Ops.SHRINK, arg=[(x[1]-x[0], x[0]) if x is not None else (s, 0) for x, s in zip(arg, self.shape)])
|
||||
ret = self._mop(Ops.SHRINK, arg=[(x[0], x[1]-x[0]) if x is not None else (0, s) for x, s in zip(arg, self.shape)])
|
||||
return self if ret.shape == self.shape else ret
|
||||
|
||||
def permute(self, order, *args) -> Self:
|
||||
|
|
@ -251,7 +251,7 @@ class MovementMixin:
|
|||
return self.shrink(tuple([None if ns is None else (0, ns) for ns in argfix(shape, *args)]))
|
||||
|
||||
def pad_to(self, shape, *args) -> Self:
|
||||
return self._mop(Ops.PAD, tuple((s if ns is None else ns, 0) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)))
|
||||
return self._mop(Ops.PAD, tuple((0, s if ns is None else ns) for s,ns in zip(self.shape, argfix(shape, *args), strict=True)))
|
||||
|
||||
def view(self, shape, *args) -> Self:
|
||||
"""`.view` is an alias for `.reshape`."""
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:U
|
|||
@functools.cache
|
||||
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
|
||||
match op:
|
||||
case Ops.SHRINK: rngs = tuple(a if off == 0 else a+off for a,(_,off) in zip(rngs, arg))
|
||||
case Ops.SHRINK: rngs = tuple(a if off == 0 else a+off for a,(off,_) in zip(rngs, arg))
|
||||
case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
|
||||
case Ops.FLIP: rngs = tuple(((s-1)-a) if f else a for a,s,f in zip(rngs, in_shape, arg))
|
||||
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
|
||||
|
|
@ -136,7 +136,7 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
|||
# NOTE: the .where(r-s, i) is not inside the graph_rewrite so that `convert_pad_to_where_to_keep_behavior_local`
|
||||
# wraps the pad with only the newly added valid
|
||||
rngs = tuple(r if (sz == sh and off == 0) else graph_rewrite((r >= off) & (r < (sh+off)),
|
||||
symbolic+pm_simplify_valid, name="pad").where(r-off, UOp.invalid()) for r,sh,(sz,off) in zip(rngs, in_shape, arg))
|
||||
symbolic+pm_simplify_valid, name="pad").where(r-off, UOp.invalid()) for r,sh,(off,sz) in zip(rngs, in_shape, arg))
|
||||
case Ops.RESHAPE:
|
||||
sink = UOp.sink(*rngs).simplify() # NOTE: this applies any commutative flips to the rngs early
|
||||
sub_array = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}
|
||||
|
|
|
|||
|
|
@ -88,8 +88,8 @@ def expand_multi(root:UOp, multi:UOp):
|
|||
return multi.src[0].expand(new_shape).multi(multi.axis)
|
||||
|
||||
def pad_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or root.marg[multi.axis] == (multi.shape[multi.axis], 0), f"padding not supported for {root.marg=}"
|
||||
local_pad = tuple((multi.src[0].shape[multi.axis], 0) if a == multi.axis else s for a,s in enumerate(root.marg))
|
||||
assert multi.axis is None or root.marg[multi.axis] == (0, multi.shape[multi.axis]), f"padding not supported for {root.marg=}"
|
||||
local_pad = tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.marg))
|
||||
return multi.src[0]._mop(Ops.PAD, local_pad).multi(multi.axis)
|
||||
|
||||
def permute_multi(root:UOp, multi:UOp):
|
||||
|
|
@ -97,15 +97,15 @@ def permute_multi(root:UOp, multi:UOp):
|
|||
return multi.src[0].permute(root.marg).multi(root.axis)
|
||||
|
||||
def shrink_multi(root:UOp, multi:UOp):
|
||||
shard_bounds = tuple((e-s,s) for s,e in multi.bounds) if multi.axis is not None else ()
|
||||
assert multi.axis is None or root.marg[multi.axis] == (multi.shape[multi.axis], 0) or root.marg[multi.axis] in shard_bounds, \
|
||||
shard_bounds = tuple((s,e-s) for s,e in multi.bounds) if multi.axis is not None else ()
|
||||
assert multi.axis is None or root.marg[multi.axis] == (0, multi.shape[multi.axis]) or root.marg[multi.axis] in shard_bounds, \
|
||||
f"shrinking not supported for {root.marg=}"
|
||||
if multi.axis is not None and root.marg[multi.axis] in shard_bounds and root.marg[multi.axis] != (multi.shape[multi.axis], 0):
|
||||
if multi.axis is not None and root.marg[multi.axis] in shard_bounds and root.marg[multi.axis] != (0, multi.shape[multi.axis]):
|
||||
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
|
||||
# we just copy it to all the devices, no real. this will be optimized out later
|
||||
non_shard_shrink = tuple((multi.src[0].shape[i], 0) if i == multi.axis else s for i, s in enumerate(root.marg))
|
||||
non_shard_shrink = tuple((0, multi.src[0].shape[i]) if i == multi.axis else s for i, s in enumerate(root.marg))
|
||||
return multi.src[0].copy_to_device(multi.device, arg=shard_bounds.index(root.marg[multi.axis]))._mop(Ops.SHRINK, non_shard_shrink)
|
||||
local_shrink = tuple((multi.src[0].shape[multi.axis], 0) if a == multi.axis else s for a,s in enumerate(root.marg))
|
||||
local_shrink = tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.marg))
|
||||
return multi.src[0]._mop(Ops.SHRINK, local_shrink).multi(multi.axis)
|
||||
|
||||
def flip_multi(root:UOp, multi:UOp):
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ def found_after(ctx:dict[UOp, UOp], after:UOp, src:UOp):
|
|||
if x.op is Ops.PERMUTE: x, after = x.src[0], after.permute(argsort(x.marg))
|
||||
elif x.op is Ops.RESHAPE: x, after = x.src[0], after.reshape(x.src[0].shape)
|
||||
elif x.op is Ops.WHERE and x.src[2].base.arg == Invalid and x.src[1].op is Ops.PAD:
|
||||
x, after = x.src[1].src[0], after.shrink(tuple((o, s+o) for (_,o),s in zip(x.src[1].marg, x.src[1].src[0].shape)))
|
||||
x, after = x.src[1].src[0], after.shrink(tuple((o, s+o) for (o,_),s in zip(x.src[1].marg, x.src[1].src[0].shape)))
|
||||
else: break
|
||||
ctx[x] = after
|
||||
|
||||
|
|
|
|||
|
|
@ -327,14 +327,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
return tuple(ps[i] for i in self.marg)
|
||||
case Ops.PAD:
|
||||
# TODO: why do i need resolve here?
|
||||
if len(ps) != len(self.marg) or not all(resolve(sz>=0) and resolve(0<=o) and resolve(o+s<=sz) for s,(sz,o) in zip(ps, self.marg)):
|
||||
if len(ps) != len(self.marg) or not all(resolve(sz>=0) and resolve(0<=o) and resolve(o+s<=sz) for s,(o,sz) in zip(ps, self.marg)):
|
||||
raise ValueError(f"invalid pad {self.marg} for {ps}")
|
||||
return tuple(ssimplify(sz) for sz,_ in self.marg)
|
||||
return tuple(ssimplify(sz) for _,sz in self.marg)
|
||||
case Ops.SHRINK:
|
||||
# TODO: why do i need resolve here?
|
||||
if len(ps) != len(self.marg) or not all(resolve(0<=b) and resolve(sz>=0) and resolve(b+sz<=s) for s,(sz,b) in zip(ps, self.marg)):
|
||||
if len(ps) != len(self.marg) or not all(resolve(0<=o) and resolve(sz>=0) and resolve(o+sz<=s) for s,(o,sz) in zip(ps, self.marg)):
|
||||
raise ValueError(f"invalid shrink {self.marg} for {ps}")
|
||||
return tuple(ssimplify(sz) for sz,_ in self.marg)
|
||||
return tuple(ssimplify(sz) for _,sz in self.marg)
|
||||
case Ops.FLIP:
|
||||
if len(ps) != len(self.marg) or not all(isinstance(x, bool) for x in self.marg): raise ValueError(f"bad flip on {ps}, {self.marg}")
|
||||
return ps
|
||||
|
|
@ -619,7 +619,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
|
||||
if len(self.src) == 0: return None
|
||||
src_axis = self.src[0].axis
|
||||
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (self.src[0].shape[src_axis], 0):
|
||||
if self.op is Ops.SHRINK and src_axis is not None and self.marg[src_axis] != (0, self.src[0].shape[src_axis]):
|
||||
return None # SHRINK will remove the sharding if it's on axis
|
||||
if self.op is Ops.REDUCE: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
|
||||
if self.op is Ops.RESHAPE:
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
if u.op in GroupOp.Movement: excluded.update(s for s in u.src if s.op is Ops.STACK)
|
||||
if u.op in GroupOp.Movement: excluded.update(s for s in u.src if s.op is Ops.STACK and all(x.op is Ops.CONST for x in s.src))
|
||||
for u in toposort:
|
||||
argst = codecs.decode(str(u.arg), "unicode_escape")
|
||||
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.marg)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue