Revert "sparse cat cross entropy (#1591)" (#1596)

This reverts commit f0ee850e98.
This commit is contained in:
George Hotz 2023-08-21 10:04:26 -07:00 committed by GitHub
commit 2e60920317
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 35 deletions

View file

@ -115,13 +115,13 @@ class TinyNet:
x = self.l1(x)
x = x.leakyrelu()
x = self.l2(x)
return x
return x.log_softmax()
net = TinyNet()
```
We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor `x`.
We can also see that functional operations like `leakyrelu` are not defined as classes and instead are just methods we can just call.
We can also see that functional operations like `leakyrelu` and `log_softmax` are not defined as classes and instead are just methods we can just call.
Finally, we just initialize an instance of our neural network, and we are ready to start training it.
## Training
@ -137,18 +137,18 @@ First we need to set the training flag in `Tensor`:
Tensor.training = True
```
For our loss function we will be using sparse categorical cross entropy loss.
For our loss function we will be using cross entropy loss.
```python
# from tinygrad.tensor import sparse_categorical_crossentropy
def sparse_categorical_crossentropy(out, Y, ignore_index=-1):
loss_mask = Y != ignore_index
# from extra.training import sparse_categorical_crossentropy
def cross_entropy(out, Y):
num_classes = out.shape[-1]
y_counter = Tensor.arange(num_classes, requires_grad=False).unsqueeze(0).expand(Y.numel(), num_classes)
y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0)
y = y * loss_mask.reshape(-1, 1)
y = y.reshape(*Y.shape, num_classes)
return out.log_softmax().mul(y).sum() / loss_mask.sum()
YY = Y.flatten().astype(np.int32)
y = np.zeros((YY.shape[0], num_classes), np.float32)
y[range(y.shape[0]),YY] = -1.0*num_classes
y = y.reshape(list(Y.shape)+[num_classes])
y = Tensor(y)
return out.mul(y).mean()
```
As we can see in this implementation of cross entropy loss, there are certain operations that tinygrad does not support.
@ -187,13 +187,13 @@ for step in range(1000):
samp = np.random.randint(0, X_train.shape[0], size=(64))
batch = Tensor(X_train[samp], requires_grad=False)
# get the corresponding labels
labels = Tensor(Y_train[samp])
labels = Y_train[samp]
# forward pass
out = net(batch)
# compute loss
loss = sparse_categorical_crossentropy(out, labels)
loss = cross_entropy(out, labels)
# zero gradients
opt.zero_grad()

View file

@ -8,7 +8,7 @@ from tinygrad.nn import BatchNorm2d, optim
from tinygrad.helpers import getenv
from extra.datasets import fetch_mnist
from extra.augment import augment_img
from extra.training import train, evaluate
from extra.training import train, evaluate, sparse_categorical_crossentropy
GPU = getenv("GPU")
QUICK = getenv("QUICK")
DEBUG = getenv("DEBUG")
@ -93,7 +93,7 @@ class BigConvNet:
x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
xo = x1.dot(self.weight1) + x2.dot(self.weight2)
return xo
return xo.log_softmax()
if __name__ == "__main__":
@ -102,7 +102,7 @@ if __name__ == "__main__":
BS = 32
lmbd = 0.00025
lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
lossfn = lambda out,y: sparse_categorical_crossentropy(out, y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
X_train, Y_train, X_test, Y_test = fetch_mnist()
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)

View file

@ -3,14 +3,24 @@ from tqdm import trange
from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import getenv
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
def sparse_categorical_crossentropy(out, Y):
num_classes = out.shape[-1]
YY = Y.flatten().astype(np.int32)
y = np.zeros((YY.shape[0], num_classes), np.float32)
# correct loss for NLL, torch NLL loss returns one per row
y[range(y.shape[0]),YY] = -1.0*num_classes
y = y.reshape(list(Y.shape)+[num_classes])
y = Tensor(y)
return out.mul(y).mean()
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy,
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
Tensor.training = True
losses, accuracies = [], []
for i in (t := trange(steps, disable=getenv('CI', False))):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(transform(X_train[samp]), requires_grad=False)
y = Tensor(target_transform(Y_train[samp]))
y = target_transform(Y_train[samp])
# network
out = model.forward(x) if hasattr(model, 'forward') else model(x)

View file

@ -12,17 +12,6 @@ import pytest
pytestmark = [pytest.mark.exclude_cuda]
class TestNN(unittest.TestCase):
def test_sparse_cat_cross_entropy(self):
input = torch.randn(3, 5)
target = torch.empty(3, dtype=torch.long).random_(5)
loss_fun = torch.nn.CrossEntropyLoss(reduction='mean')
loss = loss_fun(input, target)
input_tiny = Tensor(input.detach().numpy())
taret_tiny = Tensor(target.detach().numpy())
loss_tiny = input_tiny.sparse_categorical_crossentropy(taret_tiny)
np.testing.assert_allclose(loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6)
def test_batchnorm2d(self, training=False):
szs = [4, 8, 16, 32]

View file

@ -708,12 +708,6 @@ class Tensor:
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
loss_mask = Y != ignore_index
y_counter = Tensor.arange(self.shape[-1], requires_grad=False).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
return self.log_softmax().mul(y).sum() / loss_mask.sum()
# ***** cast ops *****
def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self