mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
refactor: canonicalize axis
This commit is contained in:
parent
70544e7e9f
commit
8a02bd56a1
3 changed files with 22 additions and 20 deletions
|
|
@ -55,28 +55,25 @@ class Exp(Function):
|
|||
# ************* reduce ops *************
|
||||
|
||||
class Sum(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
def forward(ctx, input, axis):
|
||||
ctx.save_for_backward(input, axis)
|
||||
return input.sum(axis) if axis != None else input.sum().reshape((1,))
|
||||
return input.sum(axis)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, axis = ctx.saved_tensors
|
||||
if isinstance(axis, int): axis = [axis]
|
||||
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
return grad_output.reshape(shape).expand(input.shape)
|
||||
|
||||
class Max(Function):
|
||||
def forward(ctx, inp, axis=None):
|
||||
if isinstance(axis, int): axis = [axis]
|
||||
ret = inp.amax(axis=None if axis is None else tuple(axis), keepdims=True)
|
||||
def forward(ctx, inp, axis):
|
||||
ret = inp.amax(axis=axis, keepdims=True)
|
||||
ctx.save_for_backward(inp, axis, ret)
|
||||
if axis is not None:
|
||||
ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
|
||||
ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
|
||||
return ret
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, axis, ret = ctx.saved_tensors
|
||||
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
ret2 = (input==ret.reshape(shape))
|
||||
div = ret2.sum(axis=tuple(axis), keepdims=True) if axis is not None else ret2.sum()
|
||||
return ret2*grad_output.reshape(shape)/div.type(input.dtype)
|
||||
|
|
|
|||
|
|
@ -138,31 +138,27 @@ def reduce_op(ctx, code, code2, inp, axis=None, start="0.0"):
|
|||
|
||||
class Sum(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
if isinstance(axis, int): axis = [axis]
|
||||
ctx.save_for_backward(input, axis)
|
||||
ret = reduce_op(ctx, "out += a", "out", input, axis=axis)
|
||||
if axis is not None:
|
||||
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
|
||||
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
|
||||
return ret
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, axis = ctx.saved_tensors
|
||||
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
output = GPUBuffer(shape, hostbuf=grad_output)
|
||||
return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True))
|
||||
|
||||
class Max(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
if isinstance(axis, int): axis = [axis]
|
||||
ret = reduce_op(ctx, "out = max(a,out)", "out", input, axis=axis, start="-INFINITY")
|
||||
ctx.save_for_backward(input, axis, ret)
|
||||
if axis is not None:
|
||||
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
|
||||
ret.shape = tuple([input.shape[i] for i in range(len(input.shape)) if i not in axis])
|
||||
return ret
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, axis, ret = ctx.saved_tensors
|
||||
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
shape = [1 if i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
|
||||
div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
|
||||
ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
|
||||
|
|
|
|||
|
|
@ -206,9 +206,18 @@ class Tensor:
|
|||
def dot(self, w):
|
||||
return self.matmul(w)
|
||||
|
||||
# override for sum to support keepdim
|
||||
def _canonicalize_axis(self, axis):
|
||||
if axis is None: axis = range(len(self.shape))
|
||||
if isinstance(axis, int): axis = [axis]
|
||||
return tuple([x if x >= 0 else x+len(self.shape) for x in axis])
|
||||
|
||||
def sum(self, axis=None):
|
||||
return self._sum(axis=axis)
|
||||
ret = self._sum(axis=self._canonicalize_axis(axis))
|
||||
return ret.reshape(shape=(1,)) if ret.shape == () else ret
|
||||
|
||||
def max(self, axis=None):
|
||||
ret = self._max(axis=self._canonicalize_axis(axis))
|
||||
return ret.reshape(shape=(1,)) if ret.shape == () else ret
|
||||
|
||||
def mean(self, axis=None):
|
||||
out = self.sum(axis=axis)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue