mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
refactor sparse_categorical_crossentropy (#4406)
factor out the -1 * and / loss_mask.sum() for both smoothing and non-smoothing terms
This commit is contained in:
parent
3401734e54
commit
c7368515d2
1 changed files with 3 additions and 3 deletions
|
|
@ -1324,9 +1324,9 @@ class Tensor:
|
|||
# NOTE: self is a logits input
|
||||
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
|
||||
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
|
||||
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
||||
smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum()
|
||||
return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing
|
||||
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
|
||||
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
|
||||
return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
|
||||
|
||||
# ***** cast ops *****
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue