Revert "[pr] match torch rmsnorm (#16122)" (#16144)

This reverts commit 692257dd70.
This commit is contained in:
chenyu 2026-05-11 17:53:42 -04:00 committed by GitHub
commit 0b02fb6797
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 27 additions and 31 deletions

View file

@ -52,7 +52,6 @@
::: tinygrad.Tensor.linear
::: tinygrad.Tensor.sequential
::: tinygrad.Tensor.layernorm
::: tinygrad.Tensor.rmsnorm
::: tinygrad.Tensor.batchnorm
::: tinygrad.Tensor.dropout
::: tinygrad.Tensor.one_hot

View file

@ -338,11 +338,25 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
def test_rmsnorm(self):
B, T, embed_size = 4, 10, 20
class TorchRMSNorm(torch.nn.Module):
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
self.weight = torch.nn.Parameter(torch.ones(dim)) if elementwise_affine else None
torch_layer = torch.nn.RMSNorm(embed_size, eps=1e-5)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output if self.weight is None else output * self.weight
B, T, embed_size = 4, 10, 20
torch_layer = TorchRMSNorm(embed_size)
layer = RMSNorm(embed_size)
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
layer.weight.requires_grad = True
for _ in range(10):
# forward
@ -355,10 +369,10 @@ class TestNN(unittest.TestCase):
torch_z.sum().backward()
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
torch_layer = torch.nn.RMSNorm(embed_size, eps=1e-5, elementwise_affine=False)
torch_layer = TorchRMSNorm(embed_size, elementwise_affine=False)
layer = RMSNorm(embed_size, elementwise_affine=False)
for _ in range(10):
@ -372,7 +386,7 @@ class TestNN(unittest.TestCase):
torch_z.sum().backward()
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
def test_embedding(self):
B, T, embed_size, vocab_size = 4, 10, 20, 28

View file

@ -1305,23 +1305,6 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def rmsnorm(self, axis:int|tuple[int, ...]=-1, eps=1e-5) -> Self:
"""
Applies Root Mean Square Normalization to input.
- Paper: https://arxiv.org/abs/1910.07467
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.randn(8, 10, 16) * 2 + 8
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.rmsnorm()
print(t.mean().item(), t.std().item())
```
"""
return self.mul((self.square().mean(axis, keepdim=True).add(eps)).rsqrt())
def batchnorm(self, weight:Self|None, bias:Self|None, mean:Self, invstd:Self, axis:int|tuple[int, ...]=1) -> Self:
"""
Applies Batch Normalization over a mini-batch of inputs.

View file

@ -293,14 +293,14 @@ class RMSNorm:
print(norm(t).numpy())
```
"""
def __init__(self, normalized_shape:int|tuple[int, ...], eps:float=1e-5, elementwise_affine:bool=True):
self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
self.eps = eps
self.weight = Tensor.ones(dim) if elementwise_affine else None
def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
def __call__(self, x:Tensor) -> Tensor:
assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
x = x.rmsnorm(axis=self.axis, eps=self.eps)
x = self._norm(x.float()).cast(x.dtype)
return x if self.weight is None else x * self.weight
from tinygrad.uop.ops import UOp, KernelInfo, Ops, AxisType