refactor: canonicalize axis

This commit is contained in:
George Hotz 2021-11-29 17:29:18 -05:00
commit 8a02bd56a1
3 changed files with 22 additions and 20 deletions

View file

@ -55,28 +55,25 @@ class Exp(Function):
# ************* reduce ops *************
class Sum(Function):
def forward(ctx, input, axis=None):
def forward(ctx, input, axis):
ctx.save_for_backward(input, axis)
return input.sum(axis) if axis != None else input.sum().reshape((1,))
return input.sum(axis)
def backward(ctx, grad_output):
input, axis = ctx.saved_tensors
if isinstance(axis, int): axis = [axis]
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
return grad_output.reshape(shape).expand(input.shape)
class Max(Function):
def forward(ctx, inp, axis=None):
if isinstance(axis, int): axis = [axis]
ret = inp.amax(axis=None if axis is None else tuple(axis), keepdims=True)
def forward(ctx, inp, axis):
ret = inp.amax(axis=axis, keepdims=True)
ctx.save_for_backward(inp, axis, ret)
if axis is not None:
ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
return ret
def backward(ctx, grad_output):
input, axis, ret = ctx.saved_tensors
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
ret2 = (input==ret.reshape(shape))
div = ret2.sum(axis=tuple(axis), keepdims=True) if axis is not None else ret2.sum()
return ret2*grad_output.reshape(shape)/div.type(input.dtype)

View file

@ -138,31 +138,27 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
class Sum(Function):
def forward(ctx, input, axis=None):
if isinstance(axis, int): axis = [axis]
ctx.save_for_backward(input, axis)
ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
if axis is not None:
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
return ret
def backward(ctx, grad_output):
input, axis = ctx.saved_tensors
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
output = GPUBuffer(shape, hostbuf=grad_output)
return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
class Max(Function):
def forward(ctx, input, axis=None):
if isinstance(axis, int): axis = [axis]
ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis, start="-INFINITY")
ctx.save_for_backward(input, axis, ret)
if axis is not None:
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
return ret
def backward(ctx, grad_output):
input, axis, ret = ctx.saved_tensors
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))

View file

@ -206,9 +206,18 @@ class Tensor:
def dot(self, w):
return self.matmul(w)
# override for sum to support keepdim
def _canonicalize_axis(self, axis):
if axis is None: axis = range(len(self.shape))
if isinstance(axis, int): axis = [axis]
return tuple([x if x >= 0 else x+len(self.shape) for x in axis])
def sum(self, axis=None):
return self._sum(axis=axis)
ret = self._sum(axis=self._canonicalize_axis(axis))
return ret.reshape(shape=(1,)) if ret.shape == () else ret
def max(self, axis=None):
ret = self._max(axis=self._canonicalize_axis(axis))
return ret.reshape(shape=(1,)) if ret.shape == () else ret
def mean(self, axis=None):
out = self.sum(axis=axis)