mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
max op
This commit is contained in:
parent
bcb3ceeca3
commit
36579f66bf
3 changed files with 27 additions and 2 deletions
|
|
@ -80,6 +80,12 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
|
||||
def test_sum(self):
  """Run helper_test_op on one (45, 3) input, pairing a full .sum() with Tensor.sum."""
  helper_test_op(
    [(45, 3)],
    lambda t: t.sum(),
    Tensor.sum,
    device=self.device,
  )
|
||||
@cpu_only
def test_max(self):
  """Run helper_test_op on one (45, 3) input, pairing a full .max() with Tensor.max."""
  helper_test_op(
    [(45, 3)],
    lambda t: t.max(),
    Tensor.max,
    device=self.device,
  )
|
||||
@cpu_only
def test_max_axis(self):
  """Run helper_test_op on one (3, 4, 5, 6) input, reducing with max along axis 1."""
  helper_test_op(
    [(3, 4, 5, 6)],
    # NOTE(review): the first callable keeps element [0] of max(axis=1)'s result,
    # presumably the values half of a (values, indices) pair -- confirm.
    lambda t: t.max(axis=1)[0],
    lambda t: Tensor.max(t, axis=1),
    device=self.device,
  )
|
||||
def test_sum_axis(self):
  """Run helper_test_op on one (3, 4, 5, 6) input, summing over axes 1 and 2."""
  helper_test_op(
    [(3, 4, 5, 6)],
    lambda t: t.sum(axis=(1, 2)),
    lambda t: Tensor.sum(t, axis=(1, 2)),
    device=self.device,
  )
|
||||
def test_mean_axis(self):
|
||||
|
|
|
|||
|
|
@ -71,6 +71,24 @@ class Sum(Function):
|
|||
return grad_output.reshape(shape) + np.zeros_like(input)
|
||||
register('sum', Sum)
|
||||
|
||||
class Max(Function):
  """Maximum reduction along a single axis, or over the whole array when axis is None."""

  @staticmethod
  def forward(ctx, input, axis=None):
    """Return the maximum of `input` along `axis` (all elements if axis is None).

    The argmax indices are saved on `ctx`, shaped so they can be used directly
    with np.take_along_axis / np.put_along_axis in both passes.
    """
    am = input.argmax(axis=axis)
    if axis is not None:
      # restore the reduced dimension so the indices line up with `input`
      am = np.expand_dims(am, axis=axis)
    else:
      # axis=None works on the flattened array, which needs 1-D indices
      am = np.array([am])
    ctx.save_for_backward(input.shape, am, axis)
    return np.take_along_axis(input, am, axis=axis).squeeze(axis=axis)

  @staticmethod
  def backward(ctx, grad_output):
    """Route `grad_output` to the argmax positions; every other entry gets zero.

    The incoming gradient is scattered as-is rather than replaced by a constant,
    so the result is correct for any downstream op, not only a mean over the output.
    """
    shape, am, axis = ctx.saved_tensors
    ret = np.zeros(shape, dtype=np.float32)
    # grad_output has the reduced shape; reshape it to match the index array
    grad = np.asarray(grad_output).reshape(am.shape)
    np.put_along_axis(ret, am, grad, axis=axis)
    return ret
|
||||
register('max', Max)
|
||||
|
||||
# ************* GEMM *************
|
||||
|
||||
|
|
|
|||
|
|
@ -226,9 +226,10 @@ class Tensor:
|
|||
return self.relu() - (-neg_slope*self).relu()
|
||||
|
||||
def softmax(self):
  """Return exp(self) divided by its sum along the last axis.

  The sum is reshaped to keep a trailing dimension of size 1 so the division
  lines up with `e` along every leading axis.
  """
  # shape of the per-row sum with the last axis kept as size 1
  ns = list(self.shape)[:-1]+[1]
  # TODO: Replace with (self - self.max()), i.e.
  #   e = (self - self.max(axis=len(self.shape)-1).reshape(shape=ns)).exp()
  e = self.exp()
  # computed once; the earlier duplicate sum/reshape was discarded unused
  ss = e.sum(axis=len(self.shape)-1).reshape(shape=ns)
  return e.div(ss)
|
||||
|
||||
def logsoftmax(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue