mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
linalg cosmetic change (#12356)
This commit is contained in:
parent
6a56d3c859
commit
86c5c969ea
2 changed files with 35 additions and 38 deletions
|
|
@ -1,29 +1,26 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
import unittest, functools
|
||||
from tinygrad import Tensor
|
||||
from typing import List
|
||||
import functools
|
||||
import numpy as np
|
||||
|
||||
# orthogonality_helper: test helper asserting that A @ A^T (transpose taken over the last two axes)
# equals the identity, i.e. that the rows of every trailing matrix of A are orthonormal.
# `tolerance` is passed to numpy as both atol and rtol.
# NOTE(review): this span is a rendered diff; the `def` line and the `A_identity` line each appear
# twice (differing only in spacing / float-literal spelling) because both sides of the change are shown.
def orthogonality_helper(A:Tensor,tolerance=1.0e-5):
|
||||
def orthogonality_helper(A:Tensor, tolerance=1e-5):
|
||||
# b_shape = all leading (batch) dims of A, m = size of the second-to-last axis
b_shape,m = A.shape[0:-2],A.shape[-2] #outer dimension should be the dim along orthogonality
|
||||
# m x m identity, given len(b_shape) leading singleton dims and then expanded to b_shape+(m,m)
A_identity = (Tensor.eye(m).reshape((1,) * len(b_shape)+(m,m)).expand(b_shape+(m,m)))
|
||||
A_identity = (Tensor.eye(m).reshape((1,)*len(b_shape)+(m,m)).expand(b_shape+(m,m)))
|
||||
# raises AssertionError (from numpy) when A @ A^T differs from the identity beyond the tolerance
np.testing.assert_allclose((A @ A.transpose(-2,-1)).numpy(),A_identity.numpy(),atol=tolerance,rtol=tolerance)
|
||||
|
||||
# reconstruction_helper: test helper asserting that the matrix product of the tensors in A,
# multiplied left to right, matches B. `tolerance` is passed to numpy as both atol and rtol.
# NOTE(review): this span is a rendered diff; the two `def` lines are the two sides of the change
# (`List[Tensor]` / `1.0e-5` versus `list[Tensor]` / `1e-5`).
def reconstruction_helper(A:List[Tensor],B:Tensor, tolerance=1.0e-5):
|
||||
def reconstruction_helper(A:list[Tensor],B:Tensor, tolerance=1e-5):
|
||||
# fold the list with matmul: ((A[0] @ A[1]) @ A[2]) ...
reconstructed_tensor = functools.reduce(Tensor.matmul, A)
|
||||
# raises AssertionError (from numpy) when the product differs from B beyond the tolerance
np.testing.assert_allclose(reconstructed_tensor.numpy(),B.numpy(),atol=tolerance,rtol=tolerance)
|
||||
|
||||
class TestLinAlg(unittest.TestCase):
|
||||
|
||||
# test_svd_general: for square, tall, wide and batched random inputs, checks that the U and V
# returned by svd pass orthogonality_helper and that U @ s_diag @ V reproduces the input.
# V is multiplied as returned (no transpose is applied here).
# NOTE(review): this span is a rendered diff; the paired `U,S,V = ...` lines and the paired
# `.pad(...)` lines are the two sides of the change, not statements that run twice.
def test_svd_general(self):
|
||||
# last two entries of each size are the matrix dims; any leading entries are batch dims
sizes = [(2,2),(5,3),(3,5),(3,4,4),(2,2,2,2,3)]
|
||||
for size in sizes:
|
||||
a = Tensor.randn(size).realize()
|
||||
U,S,V = Tensor.svd(a)
|
||||
U,S,V = a.svd()
|
||||
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||
k = min(m,n)
|
||||
# multiply S (with an extra axis inserted at -2) by a k x k identity so only the diagonal survives
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)))
|
||||
# grow the k x k diagonal block to m x n by padding the end of the last two axes;
# one spelling passes (0,0) per batch dim, the other passes None per batch dim
# (presumably equivalent "no padding" entries - confirm against Tensor.pad)
s_diag = s_diag.expand(b_shape + (k,k)).pad(tuple([(0,0) for _ in range(len(size)-2)] + [(0,m-k), (0,n-k)]))
|
||||
s_diag = s_diag.expand(b_shape + (k,k)).pad(tuple([None]*len(b_shape) + [(0,m-k), (0,n-k)]))
|
||||
orthogonality_helper(U)
|
||||
orthogonality_helper(V)
|
||||
reconstruction_helper([U,s_diag,V],a)
|
||||
|
|
@ -32,7 +29,7 @@ class TestLinAlg(unittest.TestCase):
|
|||
sizes = [(2,2),(5,3),(3,5),(2,2,2,2,3)]
|
||||
for size in sizes:
|
||||
a = Tensor.randn(size).realize()
|
||||
U,S,V = Tensor.svd(a,full_matrices=False)
|
||||
U,S,V = a.svd(full_matrices=False)
|
||||
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||
k = min(m,n)
|
||||
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)).expand(b_shape + (k,k)))
|
||||
|
|
@ -45,20 +42,20 @@ class TestLinAlg(unittest.TestCase):
|
|||
# test_svd_large: same orthogonality and reconstruction checks as the general SVD test, on a
# single 1024x1024 random matrix and with the tolerance loosened to 1e-3.
# NOTE(review): this span is a rendered diff; the `U,S,V = ...` line and the block of
# `.pad(...)` plus three helper calls each appear in two versions (the two sides of the change).
def test_svd_large(self):
|
||||
size = (1024,1024)
|
||||
a = Tensor.randn(size).realize()
|
||||
U,S,V = Tensor.svd(a)
|
||||
U,S,V = a.svd()
|
||||
# b_shape is the empty tuple here because size has only two entries
b_shape,m,n = size[0:-2],size[-2],size[-1]
|
||||
k = min(m,n)
|
||||
# multiply S (with an extra axis inserted at -2) by a k x k identity so only the diagonal survives
s_diag = (S.unsqueeze(-2) * Tensor.eye(k).reshape((1,) * len(b_shape) + (k,k)))
|
||||
# pad the diagonal block out to m x n; with m == n == k the pad amounts are all zero
s_diag = s_diag.expand(b_shape + (k,k)).pad(tuple([(0,0) for _ in range(len(size)-2)] + [(0,m-k), (0,n-k)]))
|
||||
orthogonality_helper(U,tolerance=1.0e-3)
|
||||
orthogonality_helper(V,tolerance=1.0e-3)
|
||||
reconstruction_helper([U,s_diag,V],a,tolerance=1.0e-3)
|
||||
s_diag = s_diag.expand(b_shape + (k,k)).pad(tuple([None]*len(b_shape) + [(0,m-k), (0,n-k)]))
|
||||
orthogonality_helper(U,tolerance=1e-3)
|
||||
orthogonality_helper(V,tolerance=1e-3)
|
||||
reconstruction_helper([U,s_diag,V],a,tolerance=1e-3)
|
||||
|
||||
# test_qr_general: for square, wide, tall and batched random inputs, checks that the Q returned
# by qr passes orthogonality_helper and that Q @ R reproduces the input.
# NOTE(review): this span is a rendered diff; the two `Q,R = ...` lines are the two sides of the
# change (function-style call versus method call).
def test_qr_general(self):
|
||||
# last two entries of each size are the matrix dims; any leading entries are batch dims
sizes = [(3,3),(3,6),(6,3),(2,2,2,2,2)]
|
||||
for size in sizes:
|
||||
a = Tensor.randn(size).realize()
|
||||
Q,R = Tensor.qr(a)
|
||||
Q,R = a.qr()
|
||||
orthogonality_helper(Q)
|
||||
reconstruction_helper([Q,R],a)
|
||||
|
||||
|
|
@ -68,9 +65,9 @@ class TestLinAlg(unittest.TestCase):
|
|||
for coefs in coefficients:
|
||||
for size in sizes:
|
||||
a = Tensor.randn(size)
|
||||
b = Tensor.newton_schulz(a, steps=20, params=coefs, eps=0.0)
|
||||
b = a.newton_schulz(steps=20, params=coefs, eps=0.0)
|
||||
# ns(A) = U @ Vt -> (U @ Vt) @ (U @ Vt)t = I
|
||||
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-1)
|
||||
orthogonality_helper(b if size[-1] > size[-2] else b.transpose(-2, -1), tolerance=1e-3)
|
||||
|
||||
# Run the unittest test runner when this module is executed as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -4090,24 +4090,24 @@ class Tensor(MathTrait):
|
|||
"""
|
||||
assert self.ndim > 1, "NS only works for two or more dims"
|
||||
G = self / (self.square().sum(axis=(-2, -1), keepdim=True).sqrt() + eps)
|
||||
G = G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
|
||||
if (swap := self.shape[-2] > self.shape[-1]): G = G.transpose(-2, -1)
|
||||
for _ in range(steps): G = sum(p * functools.reduce(lambda x, y: (y @ y.transpose(-2, -1)) @ x, [G]*i, G) for i,p in enumerate(params))
|
||||
return G.transpose(-2, -1) if self.shape[-2] > self.shape[-1] else G
|
||||
return G.transpose(-2, -1) if swap else G
|
||||
|
||||
# qr: QR decomposition over the last two axes using Householder reflections.
# Returns (Q, R): Q starts as an m x m identity (batched over the leading dims) and accumulates
# the reflectors; R starts as a clone of self and has one column reduced per iteration.
# Asserts that the input has at least two dimensions.
# NOTE(review): this span is a rendered diff; the `b_shape, m, n = ...`, `Q = ...`, `for i ...`,
# `w = ...`, `tau = ...` and `Q[..., :, i:m] = ...` statements each appear in two versions
# (the two sides of the change), not as statements that run twice.
def qr(self) -> tuple[Tensor, Tensor]:
|
||||
assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
|
||||
# b_shape = leading (batch) dims, m x n = trailing matrix dims
b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
|
||||
R = self.clone()
|
||||
b_shape, m, n = self.shape[0:self.ndim - 2], int(R.shape[-2]), int(R.shape[-1])
|
||||
# batched m x m identity; .contiguous() precedes the slice assignments into Q below
Q = Tensor.eye(m, dtype = self.dtype).reshape((1,) * (len(self.shape) - 2) + 2 * (m,)).expand(b_shape + 2 * (m,)).contiguous()
|
||||
for i in range(int(min(m, n))):
|
||||
Q = Tensor.eye(m, dtype=self.dtype).reshape((1,) * len(b_shape) + (m, m)).expand(b_shape + (m, m)).contiguous()
|
||||
# one reflector per column, up to min(m, n) columns
for i in range(min(m, n)):
|
||||
# x: column i of R from row i downwards
x = R[..., i:m, i]
|
||||
# s: negated sign of the leading entry of x
s = -x[..., 0].sign()
|
||||
# u1 = x0 - s*||x||; with s = -sign(x0) the two terms share a sign, so they add rather than cancel
# NOTE(review): if x0 is exactly 0 and sign(0) is 0, then u1 is 0 and the division below is
# ill-defined - confirm Tensor.sign semantics and whether such inputs need handling
u1 = x[..., 0] - s * x.square().sum(-1).sqrt()
|
||||
# w: reflector vector, x scaled by 1/u1, with its first entry then set to 1
w = x.unsqueeze(-1) / u1.reshape(b_shape + 2 * (1,))
|
||||
w = x.unsqueeze(-1) / u1.reshape(b_shape + (1, 1))
|
||||
w[..., 0, 0] = 1
|
||||
# tau: scalar factor of the reflector H = I - tau * w @ w^T, shaped to broadcast against w
tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + 2 * (1,)).expand(w.shape)
|
||||
tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + (1, 1))
|
||||
# apply H from the left to rows i..m-1 of R
R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
|
||||
# apply H from the right to columns i..m-1 of Q
Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau.transpose(-2, -1) * w.transpose(-2, -1))
|
||||
Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau * w).transpose(-2, -1)
|
||||
return Q,R
|
||||
|
||||
def svd(self, full_matrices = True) -> tuple[Tensor, Tensor, Tensor]:
|
||||
|
|
@ -4115,14 +4115,14 @@ class Tensor(MathTrait):
|
|||
assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
|
||||
b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
|
||||
#preprocess the matrix
|
||||
Q, R = (Tensor.qr(self) if m >= n else Tensor.qr(self.transpose(-2, -1)))
|
||||
num, q_num = int(min(m, n)), int(max(m, n))
|
||||
U = R.shrink(tuple([(0, self.shape[i]) for i in range(self.ndim - 2)] + [(0, num), (0, num)])).contiguous()
|
||||
V = Tensor.eye(num, dtype = self.dtype).reshape((1,) * (self.ndim - 2) + (num, num)).expand(b_shape + 2 * (num,)).contiguous()
|
||||
Q, R = (self.qr() if m >= n else self.transpose(-2, -1).qr())
|
||||
num, q_num = min(m, n), max(m, n)
|
||||
U = R.shrink(tuple([None] * len(b_shape) + [(0, num), (0, num)])).contiguous()
|
||||
V = Tensor.eye(num, dtype=self.dtype).reshape((1,) * len(b_shape) + (num, num)).expand(b_shape + (num, num)).contiguous()
|
||||
#prepare round robin pairing
|
||||
permute, inverse_permute = Tensor.arange(0, num, dtype = dtypes.int), Tensor.zeros(num, dtype = dtypes.int).contiguous()
|
||||
permute, inverse_permute = Tensor.arange(0, num, dtype=dtypes.int), Tensor.zeros(num, dtype=dtypes.int).contiguous()
|
||||
permute[num//2:num] = permute[num//2:num].flip(0)
|
||||
inverse_permute[permute] = Tensor.arange(num, dtype = dtypes.int)
|
||||
inverse_permute[permute] = Tensor.arange(num, dtype=dtypes.int)
|
||||
def one_round_jacobi(U, V,permute,inverse_permute):
|
||||
#pair all the columns
|
||||
V_permuted, runoff_V = (V[..., permute].split(num - 1, -1)) if num % 2 == 1 else (V[..., permute], None)
|
||||
|
|
@ -4146,15 +4146,15 @@ class Tensor(MathTrait):
|
|||
else: permute = permute[0].reshape(1).cat(((permute[1:num] - 2) % (num - 1)) + 1)
|
||||
inverse_permute = inverse_permute.scatter(0,permute,Tensor.arange(num,dtype=dtypes.int32))
|
||||
return U, V, permute, inverse_permute
|
||||
max_iterations, iterations_per_round = 1, int((num) * math.log2(num) * 2 + 2)#sorta heuristic, most use num*log2(num)
|
||||
max_iterations, iterations_per_round = 1, int(num * math.log2(num) * 2 + 2)#sorta heuristic, most use num*log2(num)
|
||||
for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
|
||||
#extract singular values and sort. construct U from Q
|
||||
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
|
||||
new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + 2 * (num,)).contiguous()
|
||||
new_indices[..., :num] = indices.reshape(b_shape + (1,) + (num,)).expand(b_shape + 2 * (num,))
|
||||
U,V = U.gather(-1, new_indices[...,0:num,0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
|
||||
new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + (num, num)).contiguous()
|
||||
new_indices[..., :num] = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
|
||||
U, V = U.gather(-1, new_indices[...,0:num,0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
|
||||
|
||||
padded_u = Tensor.eye(q_num, dtype = U.dtype).reshape((1,) * (self.ndim - 2) + 2 * (q_num,)).expand(b_shape + 2 * (q_num,)).contiguous()
|
||||
padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
|
||||
padded_u[..., 0:num, 0:num] = U
|
||||
U = Q @ padded_u
|
||||
if not full_matrices: U, V = U[..., 0:num], V[..., 0:num]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue