tinygrad/test/test_linalg.py
kevvz e2873a3a41
[bounty] Muon optim (#11414)
* newton schulz

* add muon + move newton schulz to tensor

* compact newton schulz

* better tests

* cleanup

* add comments for muon

* cleanup

* add export with tests

* match muon optim with test optim

* cleanup

* unused import

* correct comment

* whitespace

* move export

* muon test fix

* match reference impl + tests

* remove export by moving muon device

* add credit

* cleanup

* remove print

* spacing

* spacing

* comma

* cleanup

* removal

* fix tests + optim momentum

* consistent is not/ not

* more consistency

* fix test

* cleanup

* fix the nones

* remove comment

* cast

* comment

* comment

* muon teeny test

* muon flag beautiful mnist

* set steps

* steps as hyperparam

* match default test steps

* name

* large cleanup

* dont care about steps

* nesterov false default

* match each other impl

* steps

* switch nest

* swap defaults

* update docstring

* add no nesterov test

* ban fuse_optim

* prints

* classical momentum

* alternative condition

* recon

* pre + post wd

* false default

* detach

* signature changes

* context

* swap order

* big cleanup

* 0 step instead

* parity

* remove fuse

* remove fused

* better paper

* assert message

* correct shape check + eps

* multidim

* add eps

* cleanup

* correct assert message

* lint

* better tests

* naming

* ns_steps,ns_params

* update docstring

* docstring

* match sgd and muon together

* sandwich

* add back fused

* parity

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2025-08-13 14:27:55 -04:00

76 lines
No EOL
3.2 KiB
Python

import numpy as np
import unittest
from tinygrad import Tensor
from typing import List
import functools
def orthogonality_helper(A:Tensor,tolerance=1.0e-5):
  """Assert that ``A @ A^T`` equals the identity over the last two axes.

  The second-to-last axis of ``A`` is the one checked for orthonormality; every
  axis before it is treated as a batch axis. ``tolerance`` is passed as both the
  absolute and the relative tolerance of the comparison.
  """
  batch_dims, rows = A.shape[0:-2], A.shape[-2]
  # identity of size rows x rows, replicated across every batch axis
  eye = Tensor.eye(rows).reshape((1,) * len(batch_dims) + (rows, rows))
  expected = eye.expand(batch_dims + (rows, rows))
  gram = A @ A.transpose(-2, -1)
  np.testing.assert_allclose(gram.numpy(), expected.numpy(), atol=tolerance, rtol=tolerance)
def reconstruction_helper(A:List[Tensor],B:Tensor, tolerance=1.0e-5):
  """Assert that the left-to-right matrix product of the tensors in ``A`` matches ``B``.

  ``tolerance`` is passed as both the absolute and the relative tolerance of the
  comparison.
  """
  product = functools.reduce(Tensor.matmul, A)
  actual, expected = product.numpy(), B.numpy()
  np.testing.assert_allclose(actual, expected, atol=tolerance, rtol=tolerance)
class TestLinAlg(unittest.TestCase):
  """Checks Tensor.svd, Tensor.qr and Tensor.newton_schulz on random, optionally batched, inputs.

  Each shape (and coefficient set) runs inside its own ``subTest`` so that a failure
  reports which case broke and the remaining cases are still executed.
  """

  @staticmethod
  def _sigma_matrix(S, b_shape, k, pad_to=None):
    """Build a matrix carrying the last axis of ``S`` on its diagonal.

    Assumes ``S`` broadcasts against ``b_shape + (k, k)`` (i.e. its last axis has
    length ``k``). The result has shape ``b_shape + (k, k)``; when ``pad_to=(m, n)``
    is given it is zero-padded at the bottom/right to ``b_shape + (m, n)``.
    """
    eye = Tensor.eye(k).reshape((1,) * len(b_shape) + (k, k)).expand(b_shape + (k, k))
    sigma = S.unsqueeze(-2) * eye
    if pad_to is None:
      return sigma
    m, n = pad_to
    # batch axes are left untouched, only the two matrix axes grow
    return sigma.pad(((0, 0),) * len(b_shape) + ((0, m - k), (0, n - k)))

  def test_svd_general(self):
    """Full SVD: U and V pass the orthogonality check and U @ Sigma @ V reconstructs the input."""
    sizes = [(2, 2), (5, 3), (3, 5), (3, 4, 4), (2, 2, 2, 2, 3)]
    for size in sizes:
      with self.subTest(size=size):
        a = Tensor.randn(size).realize()
        U, S, V = Tensor.svd(a)
        b_shape, m, n = size[0:-2], size[-2], size[-1]
        s_diag = self._sigma_matrix(S, b_shape, min(m, n), pad_to=(m, n))
        orthogonality_helper(U)
        orthogonality_helper(V)
        reconstruction_helper([U, s_diag, V], a)

  def test_svd_nonfull(self):
    """Reduced SVD (``full_matrices=False``): orthogonality along the smaller dim plus reconstruction."""
    sizes = [(2, 2), (5, 3), (3, 5), (2, 2, 2, 2, 3)]
    for size in sizes:
      with self.subTest(size=size):
        a = Tensor.randn(size).realize()
        U, S, V = Tensor.svd(a, full_matrices=False)
        b_shape, m, n = size[0:-2], size[-2], size[-1]
        s_diag = self._sigma_matrix(S, b_shape, min(m, n))
        # reduced U,V is only orthogonal along smaller dim
        if m < n:
          orthogonality_helper(U)
          orthogonality_helper(V)
        else:
          orthogonality_helper(U.transpose(-2, -1))
          orthogonality_helper(V.transpose(-2, -1))
        reconstruction_helper([U, s_diag, V], a)

  @unittest.skip("very big. recommend wrapping with TinyJit around inner function")
  def test_svd_large(self):
    """Full SVD of a single 1024x1024 matrix, checked with a looser 1e-3 tolerance."""
    size = (1024, 1024)
    a = Tensor.randn(size).realize()
    U, S, V = Tensor.svd(a)
    b_shape, m, n = size[0:-2], size[-2], size[-1]
    s_diag = self._sigma_matrix(S, b_shape, min(m, n), pad_to=(m, n))
    orthogonality_helper(U, tolerance=1.0e-3)
    orthogonality_helper(V, tolerance=1.0e-3)
    reconstruction_helper([U, s_diag, V], a, tolerance=1.0e-3)

  def test_qr_general(self):
    """QR: Q passes the orthogonality check and Q @ R reconstructs the input."""
    sizes = [(3, 3), (3, 6), (6, 3), (2, 2, 2, 2, 2)]
    for size in sizes:
      with self.subTest(size=size):
        a = Tensor.randn(size).realize()
        Q, R = Tensor.qr(a)
        orthogonality_helper(Q)
        reconstruction_helper([Q, R], a)

  def test_newton_schulz(self):
    """Newton-Schulz output passes the orthogonality check (tolerance 1e-1) for each coefficient set."""
    coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]  # these params map to the sign function
    sizes = [(2, 2), (3, 2), (2, 3), (2, 2, 2)]
    for coefs in coefficients:
      for size in sizes:
        with self.subTest(coefs=coefs, size=size):
          a = Tensor.randn(size)
          b = Tensor.newton_schulz(a, steps=20, params=coefs, eps=0.0)
          # ns(A) = U @ Vt -> (U @ Vt) @ (U @ Vt)t = I
          # the helper checks X @ Xt, so hand it the orientation whose second-to-last axis is the smaller one
          orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-1)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()