mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
This reverts commit 692257dd70.
This commit is contained in:
parent
fbe8be0b8b
commit
0b02fb6797
4 changed files with 27 additions and 31 deletions
|
|
@ -52,7 +52,6 @@
|
|||
::: tinygrad.Tensor.linear
|
||||
::: tinygrad.Tensor.sequential
|
||||
::: tinygrad.Tensor.layernorm
|
||||
::: tinygrad.Tensor.rmsnorm
|
||||
::: tinygrad.Tensor.batchnorm
|
||||
::: tinygrad.Tensor.dropout
|
||||
::: tinygrad.Tensor.one_hot
|
||||
|
|
|
|||
|
|
@ -338,11 +338,25 @@ class TestNN(unittest.TestCase):
|
|||
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_rmsnorm(self):
|
||||
B, T, embed_size = 4, 10, 20
|
||||
class TorchRMSNorm(torch.nn.Module):
|
||||
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.weight = torch.nn.Parameter(torch.ones(dim)) if elementwise_affine else None
|
||||
|
||||
torch_layer = torch.nn.RMSNorm(embed_size, eps=1e-5)
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output if self.weight is None else output * self.weight
|
||||
|
||||
B, T, embed_size = 4, 10, 20
|
||||
torch_layer = TorchRMSNorm(embed_size)
|
||||
layer = RMSNorm(embed_size)
|
||||
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
|
||||
layer.weight.requires_grad = True
|
||||
|
||||
for _ in range(10):
|
||||
# forward
|
||||
|
|
@ -355,10 +369,10 @@ class TestNN(unittest.TestCase):
|
|||
torch_z.sum().backward()
|
||||
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
|
||||
|
||||
torch_layer = torch.nn.RMSNorm(embed_size, eps=1e-5, elementwise_affine=False)
|
||||
torch_layer = TorchRMSNorm(embed_size, elementwise_affine=False)
|
||||
layer = RMSNorm(embed_size, elementwise_affine=False)
|
||||
|
||||
for _ in range(10):
|
||||
|
|
@ -372,7 +386,7 @@ class TestNN(unittest.TestCase):
|
|||
torch_z.sum().backward()
|
||||
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_embedding(self):
|
||||
B, T, embed_size, vocab_size = 4, 10, 20, 28
|
||||
|
|
|
|||
|
|
@ -1305,23 +1305,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
|||
y = (self - self.mean(axis, keepdim=True))
|
||||
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
||||
|
||||
def rmsnorm(self, axis:int|tuple[int, ...]=-1, eps=1e-5) -> Self:
|
||||
"""
|
||||
Applies Root Mean Square Normalization to input.
|
||||
|
||||
- Paper: https://arxiv.org/abs/1910.07467
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.randn(8, 10, 16) * 2 + 8
|
||||
print(t.mean().item(), t.std().item())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.rmsnorm()
|
||||
print(t.mean().item(), t.std().item())
|
||||
```
|
||||
"""
|
||||
return self.mul((self.square().mean(axis, keepdim=True).add(eps)).rsqrt())
|
||||
|
||||
def batchnorm(self, weight:Self|None, bias:Self|None, mean:Self, invstd:Self, axis:int|tuple[int, ...]=1) -> Self:
|
||||
"""
|
||||
Applies Batch Normalization over a mini-batch of inputs.
|
||||
|
|
|
|||
|
|
@ -293,14 +293,14 @@ class RMSNorm:
|
|||
print(norm(t).numpy())
|
||||
```
|
||||
"""
|
||||
def __init__(self, normalized_shape:int|tuple[int, ...], eps:float=1e-5, elementwise_affine:bool=True):
|
||||
self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
|
||||
self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
|
||||
self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
|
||||
def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
|
||||
self.eps = eps
|
||||
self.weight = Tensor.ones(dim) if elementwise_affine else None
|
||||
|
||||
def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
|
||||
x = x.rmsnorm(axis=self.axis, eps=self.eps)
|
||||
x = self._norm(x.float()).cast(x.dtype)
|
||||
return x if self.weight is None else x * self.weight
|
||||
|
||||
from tinygrad.uop.ops import UOp, KernelInfo, Ops, AxisType
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue