mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
SLICE -> PAD,SHRINK
This commit is contained in:
parent
9574dd8559
commit
3c4565fa21
11 changed files with 53 additions and 47 deletions
12
README.md
12
README.md
|
|
@ -130,12 +130,12 @@ You no longer need to write mlops for a new accelerator
|
|||
The autodiff stuff is all in mlops now so you can focus on the raw operations
|
||||
|
||||
```
|
||||
Buffer # class of memory on this device
|
||||
unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
|
||||
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
|
||||
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
|
||||
movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size)
|
||||
processing_op (CONV) # A + B -> C
|
||||
Buffer # class of memory on this device
|
||||
unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
|
||||
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
|
||||
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
|
||||
movement_op (RESHAPE, PERMUTE, PAD, SHRINK, EXPAND, FLIP) # A -> B (different size)
|
||||
processing_op (CONV) # A + B -> C
|
||||
```
|
||||
|
||||
When tinygrad moves to lazy evaluation, optimizations will happen here.
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ class OpenCLBuffer(GPUBuffer):
|
|||
return self._image
|
||||
|
||||
seen = set()
|
||||
SUPPORTS_PADDING = True
|
||||
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, start="0.0"):
|
||||
if C is None:
|
||||
# TODO: handle an opencl conv without the conv part
|
||||
|
|
|
|||
|
|
@ -10,31 +10,31 @@ def preprocessing_op(x,w,C):
|
|||
if C.bs > 1 and C.py > 0:
|
||||
# explicitly add y-padding for batched inputs
|
||||
# N C H W
|
||||
xs = [(0, s) for s in x.shape]
|
||||
xs[2] = (-C.py, x.shape[2]+C.py)
|
||||
x = x.movement_op(MovementOps.SLICE, xs)
|
||||
xs = [(0, 0) for _ in x.shape]
|
||||
xs[2] = (C.py, C.py)
|
||||
x = x.movement_op(MovementOps.PAD, xs)
|
||||
C = C._replace(iy=C.iy + C.py*2, py=0)
|
||||
|
||||
# hack for non multiples of 4 on C.cin
|
||||
if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
|
||||
to_add = 4 - (C.cin % 4)
|
||||
ws = [(0, s) for s in w.shape]
|
||||
ws[2] = (0, w.shape[2]+to_add)
|
||||
w = w.movement_op(MovementOps.SLICE, ws)
|
||||
ws = [(0, 0) for _ in w.shape]
|
||||
ws[2] = (0, to_add)
|
||||
w = w.movement_op(MovementOps.PAD, ws)
|
||||
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
xs = [(0, s) for s in x.shape]
|
||||
xs[2] = (0, x.shape[2]+to_add)
|
||||
x = x.movement_op(MovementOps.SLICE, xs)
|
||||
xs = [(0, 0) for _ in x.shape]
|
||||
xs[2] = (0, to_add)
|
||||
x = x.movement_op(MovementOps.PAD, xs)
|
||||
C = C._replace(cin = C.cin + to_add)
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))
|
||||
|
||||
# hack for non multiples of 4 on C.rcout
|
||||
if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
|
||||
added_output_channels = 4 - (C.rcout % 4)
|
||||
ws = [(0, s) for s in w.shape]
|
||||
ws[1] = (0, w.shape[1]+added_output_channels)
|
||||
w = w.movement_op(MovementOps.SLICE, ws)
|
||||
ws = [(0, 0) for _ in w.shape]
|
||||
ws[1] = (0, added_output_channels)
|
||||
w = w.movement_op(MovementOps.PAD, ws)
|
||||
C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))
|
||||
|
||||
# packed
|
||||
|
|
@ -66,7 +66,7 @@ def postprocessing_op(ret, C, C_initial):
|
|||
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.groups, C.rcout))
|
||||
xs = [(0, s) for s in ret.shape]
|
||||
xs[4] = (0, ret.shape[4]-added_output_channels)
|
||||
ret = ret.movement_op(MovementOps.SLICE, xs)
|
||||
ret = ret.movement_op(MovementOps.SHRINK, xs)
|
||||
C = C._replace(rcout = C.rcout - added_output_channels, cout = C.groups * (C.rcout - added_output_channels))
|
||||
|
||||
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class DumbShapeTracker:
|
|||
def flip(self, *axis):
|
||||
self.t = np.flip(self.t, axis)
|
||||
|
||||
def slice(self, *arg):
|
||||
def shrink(self, *arg):
|
||||
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
|
||||
|
||||
def stride(self, *arg):
|
||||
|
|
@ -104,8 +104,8 @@ class TestSingleShapeTracker(unittest.TestCase):
|
|||
self.st.permute(1,0)
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_slice(self):
|
||||
self.st.slice((1,2), (0,4))
|
||||
def test_shrink(self):
|
||||
self.st.shrink((1,2), (0,4))
|
||||
assert not self.st.contiguous
|
||||
|
||||
def test_double_permute(self):
|
||||
|
|
@ -179,27 +179,27 @@ class TestShapeTracker(unittest.TestCase):
|
|||
self.apply(lambda x: x.flip(0,1))
|
||||
|
||||
def test_slice_0(self):
|
||||
self.apply(lambda x: x.slice((1, x.shape[0]), (0, x.shape[1])))
|
||||
self.apply(lambda x: x.shrink((1, x.shape[0]), (0, x.shape[1])))
|
||||
|
||||
def test_slice_1(self):
|
||||
self.apply(lambda x: x.slice((0, x.shape[0]), (1, x.shape[1])))
|
||||
self.apply(lambda x: x.shrink((0, x.shape[0]), (1, x.shape[1])))
|
||||
|
||||
def test_slice_1c1(self):
|
||||
self.apply(lambda x: x.slice((0, 1), (0, 1)))
|
||||
self.apply(lambda x: x.shrink((0, 1), (0, 1)))
|
||||
|
||||
def test_slice_1c2(self):
|
||||
self.apply(lambda x: x.slice((1, 2), (1, 2)))
|
||||
self.apply(lambda x: x.shrink((1, 2), (1, 2)))
|
||||
|
||||
def test_double_permute(self):
|
||||
self.apply(lambda x: x.permute(1, 0))
|
||||
self.apply(lambda x: x.permute(1, 0))
|
||||
|
||||
def test_slice_permute(self):
|
||||
self.apply(lambda x: x.slice((0, 2), (2, 4)))
|
||||
self.apply(lambda x: x.shrink((0, 2), (2, 4)))
|
||||
self.apply(lambda x: x.permute(1, 0))
|
||||
|
||||
def test_slice_expand(self):
|
||||
self.apply(lambda x: x.slice((0, 2), (3, 4)))
|
||||
self.apply(lambda x: x.shrink((0, 2), (3, 4)))
|
||||
self.apply(lambda x: x.expand(2, 10))
|
||||
|
||||
def test_double_stride(self):
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, o
|
|||
bs,cin_,iy,ix = x_shape
|
||||
|
||||
# this can change px_ and py_ to make the out_shape right
|
||||
# TODO: copy padding names from http://nvdla.org/hw/v1/ias/unit_description.html
|
||||
if out_shape is not None:
|
||||
py_ = (out_shape[2] - 1) * sy + 1 + dy * (H-1) - iy - py
|
||||
px_ = (out_shape[3] - 1) * sx + 1 + dx * (W-1) - ix - px
|
||||
|
|
@ -29,8 +30,7 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, o
|
|||
oy = (iy + py + py_ - dy * (H-1) - 1)//sy + 1
|
||||
ox = (ix + px + px_ - dx * (W-1) - 1)//sx + 1
|
||||
if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
|
||||
assert cout % groups == 0
|
||||
assert out_shape is None or out_shape == (bs, cout, oy, ox)
|
||||
assert cout % groups == 0 and (out_shape is None or out_shape == (bs, cout, oy, ox))
|
||||
return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, sy, sx, bs, cout, py, py_, px, px_, dy, dx, (bs, cout, oy, ox))
|
||||
|
||||
def get_available_llops():
|
||||
|
|
|
|||
|
|
@ -42,15 +42,13 @@ class CPUBuffer(np.ndarray):
|
|||
elif op == MovementOps.PERMUTE: return x.permute(arg)
|
||||
elif op == MovementOps.FLIP: return x.flip(arg)
|
||||
elif op == MovementOps.PAD: return x.custompad(arg)
|
||||
elif op == MovementOps.SLICE:
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
|
||||
elif op == MovementOps.SHRINK: return x[tuple(slice(p[0], p[1], None) for i,p in enumerate(arg))]
|
||||
elif op == MovementOps.EXPAND: return x.expand(arg)
|
||||
elif op == MovementOps.STRIDED: return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])
|
||||
|
||||
PREPAD = True
|
||||
def processing_op(x,op,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
tx = x.movement_op(MovementOps.STRIDED, (
|
||||
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
|
||||
(C.oy, C.sy*x.shape[3]), (C.ox, C.sx), (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ class GPUBuffer:
|
|||
def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)
|
||||
def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x)
|
||||
|
||||
SUPPORTS_PADDING = True
|
||||
def processing_op(x, op:ProcessingOps, w:GPUBuffer, C:ConvArgs):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)
|
||||
|
|
|
|||
|
|
@ -14,5 +14,4 @@ class TorchBuffer(torch.Tensor):
|
|||
|
||||
def processing_op(x,op,w,C):
|
||||
assert op == ProcessingOps.CONV, f"{op} isn't supported"
|
||||
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
return torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx))
|
||||
|
|
|
|||
|
|
@ -140,10 +140,10 @@ class Permute(Function):
|
|||
class Slice(Function):
|
||||
def forward(ctx, x, arg=None):
|
||||
ctx.narg = tuple((0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg))
|
||||
return x.movement_op(MovementOps.SLICE, tuple(arg))
|
||||
return x.slice(tuple(arg))
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output.movement_op(MovementOps.SLICE, ctx.narg)
|
||||
return grad_output.slice(ctx.narg)
|
||||
|
||||
class Flip(Function):
|
||||
def forward(ctx, x, axis):
|
||||
|
|
@ -172,8 +172,8 @@ class Conv2D(Function):
|
|||
xt = xt.movement_op(MovementOps.RESHAPE, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
|
||||
xt = xt.movement_op(MovementOps.PAD, ((0,0), (0,0), (0,0), (0,C.sy-1), (0,0), (0,C.sx-1)))
|
||||
xt = xt.movement_op(MovementOps.RESHAPE, (xt.shape[0], xt.shape[1], xt.shape[2]*C.sy, xt.shape[4]*C.sx))
|
||||
wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)).movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4))
|
||||
wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
|
||||
wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)).movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4)).movement_op(MovementOps.FLIP, (3, 4))
|
||||
wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W))
|
||||
py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
|
||||
Cdx = get_conv_args(xt.shape, wt.shape, out_shape=x.shape, dilation=(C.dy, C.dx), padding=(py, px), groups=C.groups)
|
||||
dx = xt.processing_op(ProcessingOps.CONV, wt, Cdx)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ sys.setrecursionlimit(10000)
|
|||
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
|
||||
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
|
||||
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
|
||||
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP", "STRIDED", "PAD"])
|
||||
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "EXPAND", "FLIP", "STRIDED", "PAD", "SHRINK"])
|
||||
ProcessingOps = Enum("ProcessingOps", ["CONV"])
|
||||
LoadOps = Enum("LoadOps", ["FROMCPU"])
|
||||
|
||||
|
|
@ -224,6 +224,11 @@ class LazyBuffer:
|
|||
def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape))) if x.shape != tuple(new_shape) else x
|
||||
|
||||
# syntactic sugar around PAD and SHRINK
|
||||
def slice(x:LazyBuffer, arg):
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
return x.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
|
||||
|
||||
def movement_op(x:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
|
||||
# TODO: look into why that copy is needed
|
||||
arg = tuple(copy(arg))
|
||||
|
|
@ -231,7 +236,8 @@ class LazyBuffer:
|
|||
# instant nops
|
||||
if op in [MovementOps.RESHAPE, MovementOps.EXPAND] and arg == x.shape: return x
|
||||
if op == MovementOps.PERMUTE and arg == tuple(range(len(x.shape))): return x
|
||||
if op == MovementOps.SLICE and arg == tuple((0,i) for i in x.shape): return x
|
||||
if op == MovementOps.SHRINK and arg == tuple((0,i) for i in x.shape): return x
|
||||
if op == MovementOps.PAD and arg == tuple((0,0) for _ in x.shape): return x
|
||||
if op == MovementOps.FLIP and all(s == 1 or i not in arg for i,s in enumerate(x.shape)): return x
|
||||
|
||||
# two reshapes in a row is one reshape
|
||||
|
|
@ -243,12 +249,10 @@ class LazyBuffer:
|
|||
# some permutes are actually just reshapes
|
||||
if op == MovementOps.PERMUTE and ShapeTracker(x.shape).movement_op(op, arg).contiguous: return x.movement_op(MovementOps.RESHAPE, tuple(x.shape[i] for i in arg))
|
||||
|
||||
# TODO: SHUFFLE_SLICE_OPS is okay if it's a shrink
|
||||
if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_PAD_OPS or op not in [MovementOps.SLICE, MovementOps.PAD]):
|
||||
if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_PAD_OPS or op != MovementOps.PAD):
|
||||
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
|
||||
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
|
||||
if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
|
||||
assert isinstance(y.op, BinaryOps) or isinstance(y.op, UnaryOps)
|
||||
return elementwise_op(y.op, *[replace_with_movement_op(z) for z in y.src])
|
||||
return replace_with_movement_op(x.op)
|
||||
|
||||
|
|
@ -265,9 +269,11 @@ class LazyBuffer:
|
|||
return ret
|
||||
|
||||
def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
|
||||
# TODO: fixup C?
|
||||
if NOCONV or not getattr(x.dbuffer, "SUPPORTS_PADDING", False): x = x.slice(((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
|
||||
if NOCONV:
|
||||
# universal conv, just mul and reduce
|
||||
x = x.movement_op(MovementOps.SLICE, ((0, x.shape[0]), (0, x.shape[1]), (-C.py, x.shape[2]+C.py_), (-C.px, x.shape[3]+C.px_)))
|
||||
# TODO: is there any way to replace strided with other movement ops?
|
||||
x = x.movement_op(MovementOps.STRIDED, (
|
||||
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
|
||||
|
|
|
|||
|
|
@ -126,9 +126,10 @@ class ShapeTracker:
|
|||
# TODO: take this functionality out of shrink
|
||||
def pad(self, *arg):
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
||||
return self.slice(*[(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
|
||||
return self.shrink(*[(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
|
||||
|
||||
def slice(self, *arg):
|
||||
# TODO: take the pad functionality out of shrink
|
||||
def shrink(self, *arg):
|
||||
assert len(arg) == len(self.shape)
|
||||
offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||
zeroview = ZeroView(self.shape, arg)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue