Skip to content

Commit c29c268

Browse files
authored
Incorporated SVD (#102)
Incorporated SVD functionality into CoLA.
1 parent d385075 commit c29c268

File tree

9 files changed

+238
-88
lines changed

9 files changed

+238
-88
lines changed

cola/backends/jax_fns.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from jax.lax import dynamic_slice, expand_dims
1111
from jax.lax import fori_loop as _for_loop
1212
from jax.lax import while_loop as _while_loop
13-
from jax.lax.linalg import cholesky, qr, svd
13+
from jax.lax.linalg import cholesky, qr
1414
from jax.random import PRNGKey, normal
1515
from jax.scipy.linalg import block_diag
1616
from jax.scipy.linalg import lu as lu_lax
@@ -58,7 +58,6 @@
5858
nan_to_num = jnp.nan_to_num
5959
dynamic_slice = dynamic_slice
6060
zeros_like = jnp.zeros_like
61-
svd = svd
6261
cholesky = cholesky
6362
solvetri = solvetri
6463
qr = qr
@@ -94,6 +93,11 @@
9493
finfo = jnp.finfo
9594

9695

96+
def svd(A, full_matrices):
97+
U, S, VH = jnp.linalg.svd(A, full_matrices=full_matrices)
98+
return U, S, VH.T.conj()
99+
100+
97101
def to_np(array):
98102
return jax.device_get(array)
99103

cola/backends/np_fns.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def __init__(self):
6565
sqrt = np.sqrt
6666
stack = np.stack
6767
sum = np.sum
68-
svd = np.linalg.svd
6968
where = np.where
7069
fft = np.fft.fft
7170
ifft = np.fft.ifft
@@ -86,6 +85,11 @@ def lstsq(A, b):
8685
return soln
8786

8887

88+
def svd(A, full_matrices):
89+
U, S, VH = np.linalg.svd(A, full_matrices=full_matrices)
90+
return U, S, VH.T.conj()
91+
92+
8993
def eig(array):
9094
eigvals, eigvecs = np.linalg.eig(array)
9195
eigvals, eigvecs = eigvals.astype(complex), eigvecs.astype(complex)

cola/backends/torch_fns.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
eigh = torch.linalg.eigh
4444
solve = torch.linalg.solve
4545
copy = torch.clone
46-
svd = torch.linalg.svd
4746
diag = torch.diag
4847
zeros_like = torch.zeros_like
4948
cholesky = torch.linalg.cholesky
@@ -68,6 +67,11 @@
6867
iscomplexobj = torch.is_complex
6968

7069

70+
def svd(A, full_matrices):
71+
U, S, VH = torch.linalg.svd(A, full_matrices=full_matrices)
72+
return U, S, VH.H
73+
74+
7175
def to_np(array):
7276
return array.detach().cpu().numpy()
7377

cola/linalg/decompositions/decompositions.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
""" Decompositions of linear operators, e.g. LU, Cholesky, Lanczos, Arnoldi, SVD"""
2-
from plum import dispatch
32
from dataclasses import dataclass
4-
from cola.ops import LinearOperator, Array
5-
from cola.ops import Triangular, Permutation, Diagonal
6-
from cola.ops import Identity, ScalarMul, Kronecker, BlockDiag
7-
from cola.utils import export
8-
from cola.linalg.algorithm_base import Algorithm
9-
from cola.linalg.decompositions.lanczos import lanczos
10-
from cola.linalg.decompositions.arnoldi import arnoldi
11-
from cola.linalg.algorithm_base import Auto
3+
from typing import Any, Optional
4+
5+
from plum import dispatch
6+
127
import cola.linalg
13-
from typing import Optional, Any
8+
from cola.annotations import Unitary
9+
from cola.linalg.algorithm_base import Algorithm, Auto
10+
from cola.linalg.decompositions.arnoldi import arnoldi
11+
from cola.linalg.decompositions.lanczos import lanczos, lanczos_eigs
12+
from cola.ops import (
13+
Array,
14+
BlockDiag,
15+
Diagonal,
16+
Identity,
17+
Kronecker,
18+
LinearOperator,
19+
Permutation,
20+
ScalarMul,
21+
Triangular,
22+
)
23+
from cola.utils import export
1424

1525
PRNGKey = Any
1626

@@ -78,6 +88,31 @@ def __call__(self, A: LinearOperator):
7888
return arnoldi(A, **self.__dict__)
7989

8090

91+
@dataclass
92+
class LanczosSVD(Algorithm):
93+
"""
94+
Does the SVD using the Lanczos decomposition,
95+
Args:
96+
start_vector (Array, optional): (n,) or (n, b) vector to start the algorithm.
97+
max_iters (int, optional): The maximum number of iterations to run.
98+
tol (float, optional): Relative error tolerance.
99+
pbar (bool, optional): Whether to show progress bar.
100+
key (PRNGKey, optional): Random key to use for the algorithm.
101+
PRNGKey for jax, long integer for numpy or pytorch.
102+
"""
103+
start_vector: Array = None
104+
max_iters: int = 1_000
105+
tol: float = 1e-6
106+
pbar: bool = False
107+
key: Optional[PRNGKey] = None
108+
109+
def __call__(self, A: LinearOperator):
110+
Sigma0_2, V, _ = lanczos_eigs(A.H @ A, **self.__dict__)
111+
Sigma1_2, U, _ = lanczos_eigs(A @ A.H, **self.__dict__)
112+
Sigma = A.xnp.sqrt(Sigma0_2)
113+
return Unitary(U), Diagonal(Sigma), Unitary(V)
114+
115+
81116
@export
82117
@dataclass
83118
class Lanczos(Algorithm):
@@ -174,3 +209,16 @@ def plu(A: BlockDiag):
174209
P, L, U = zip(*[plu(Ai) for Ai in A.Ms])
175210
BD = lambda *args: BlockDiag(*args, multiplicities=A.multiplicities) # noqa
176211
return BD(*P), BD(*L), BD(*U)
212+
213+
214+
def get_slice(num, which):
215+
if num == -1:
216+
raise ValueError(f"Number of eigenvalues {num} must be explicitly specified")
217+
if which == "SM":
218+
eig_slice = slice(0, num, None)
219+
elif which == "LM":
220+
id = -1 if num is None else -num
221+
eig_slice = slice(id, None, None)
222+
else:
223+
raise NotImplementedError(f"which={which} is not implemented")
224+
return eig_slice

cola/linalg/eig/eigs.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,24 @@
11
import numpy as np
22
from plum import dispatch
3-
from cola.annotations import Unitary, Stiefel
3+
4+
from cola.annotations import SelfAdjoint, Stiefel, Unitary
45
from cola.fns import lazify
5-
from cola.annotations import SelfAdjoint
6-
from cola.linalg.trace.diag_trace import diag
7-
from cola.ops.operator_base import LinearOperator
8-
from cola.ops.operators import Diagonal
9-
from cola.ops.operators import I_like
10-
from cola.ops.operators import Identity
11-
from cola.ops.operators import Triangular
12-
from cola.linalg.decompositions.lanczos import lanczos_eigs
13-
from cola.linalg.decompositions.arnoldi import arnoldi_eigs
146
from cola.linalg.algorithm_base import Algorithm, Auto
15-
from cola.linalg.decompositions.decompositions import Arnoldi, Lanczos
7+
from cola.linalg.decompositions.arnoldi import arnoldi_eigs
8+
from cola.linalg.decompositions.decompositions import Arnoldi, Lanczos, get_slice
9+
from cola.linalg.decompositions.lanczos import lanczos_eigs
1610
from cola.linalg.eig.lobpcg import LOBPCG, lobpcg
1711
from cola.linalg.eig.power_iteration import PowerIteration
12+
from cola.linalg.trace.diag_trace import diag
1813
from cola.linalg.unary.unary import Eig, Eigh
14+
from cola.ops.operator_base import LinearOperator
15+
from cola.ops.operators import Diagonal, I_like, Identity, Triangular
1916
from cola.utils import export
2017

2118

2219
@export
2320
@dispatch.abstract
24-
def eig(A: LinearOperator, k: int = -1, which: str = 'LM', alg: Algorithm = Auto()):
21+
def eig(A: LinearOperator, k: int, which: str = "LM", alg: Algorithm = Auto()):
2522
"""
2623
Computes eigenvalues and eigenvectors of a linear operator.
2724
@@ -183,16 +180,3 @@ def eig(A: Diagonal, k: int, which: str, alg: Algorithm):
183180
eig_vals = A.diag[sorted_ind]
184181
eig_vecs = I_like(A).to_dense()[:, sorted_ind]
185182
return eig_vals[eig_slice], Unitary(lazify(eig_vecs[:, eig_slice]))
186-
187-
188-
def get_slice(num, which):
189-
if num == -1:
190-
raise ValueError(f"Number of eigenvalues {num} must be explicitly specified")
191-
if which == "SM":
192-
eig_slice = slice(0, num, None)
193-
elif which == "LM":
194-
id = -1 if num is None else -num
195-
eig_slice = slice(id, None, None)
196-
else:
197-
raise NotImplementedError(f"which={which} is not implemented")
198-
return eig_slice

cola/linalg/svd/svd.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import numpy as np
2+
from plum import dispatch
3+
4+
from cola.annotations import Unitary
5+
from cola.fns import lazify
6+
from cola.linalg.algorithm_base import Algorithm, Auto
7+
from cola.linalg.decompositions.decompositions import Lanczos, get_slice
8+
from cola.linalg.decompositions.lanczos import lanczos_eigs
9+
from cola.linalg.eig.lobpcg import LOBPCG, lobpcg
10+
from cola.linalg.inverse.inv import inv
11+
from cola.ops.operator_base import LinearOperator
12+
from cola.ops.operators import Dense, Diagonal, I_like, Identity
13+
from cola.utils import export
14+
15+
16+
@export
17+
class DenseSVD(Algorithm):
18+
"""
19+
Performs SVD on A.
20+
"""
21+
22+
23+
@export
24+
@dispatch.abstract
25+
def svd(A: LinearOperator, k: int, which: str = "LM", alg: Algorithm = Auto()):
26+
"""
27+
Computes the SVD of the linear operator A.
28+
29+
Args:
30+
A (LinearOperator): The linear operator to decompose.
31+
alg (Algorithm): (Auto, SVD, LanczosSVD)
32+
33+
Returns:
34+
Tuple[LinearOperator]: A tuple U, D, V, such that U D V^{*} = A.
35+
"""
36+
37+
38+
@dispatch(precedence=-1)
39+
def svd(A: LinearOperator, k: int, which: str, alg: Auto):
40+
""" Auto:
41+
- if A is small, use dense SVD
42+
- if A is large, use Lanczos
43+
"""
44+
match bool(np.prod(A.shape) <= 1e6):
45+
case True:
46+
alg = DenseSVD()
47+
case False:
48+
alg = Lanczos(**alg.__dict__)
49+
return svd(A, k, which, alg)
50+
51+
52+
@dispatch
53+
def svd(A: LinearOperator, k: int, which: str, alg: DenseSVD):
54+
U, Sigma, V = A.xnp.svd(A.to_dense(), full_matrices=True)
55+
idx = A.xnp.argsort(Sigma, axis=-1)
56+
return Unitary(Dense(U[:, idx])), Diagonal(Sigma[..., idx]), Unitary(Dense(V[:, idx]))
57+
58+
59+
@dispatch
60+
def svd(A: LinearOperator, k: int, which: str, alg: Lanczos):
61+
xnp = A.xnp
62+
eig_slice = get_slice(k, which)
63+
if A.shape[1] <= A.shape[0]:
64+
eig_vals, V, _ = lanczos_eigs(A.H @ A, **alg.__dict__)
65+
V = Unitary(V[:, eig_slice])
66+
Sigma = Diagonal(xnp.sqrt(eig_vals[eig_slice]))
67+
U = Unitary(lazify((A @ V @ inv(Sigma)).to_dense()))
68+
else:
69+
eig_vals, U, _ = lanczos_eigs(A @ A.H, **alg.__dict__)
70+
U = Unitary(U[:, eig_slice])
71+
Sigma = Diagonal(xnp.sqrt(eig_vals[eig_slice]))
72+
V = Unitary(lazify((inv(Sigma) @ U.H @ A).to_dense().conj().T))
73+
return U, Sigma, V
74+
75+
76+
@dispatch
77+
def svd(A: LinearOperator, k: int, which: str, alg: LOBPCG):
78+
xnp = A.xnp
79+
eig_slice = get_slice(k, which)
80+
eig_vals, V = lobpcg(A.H @ A, **alg.__dict__)
81+
V = Unitary(V[:, eig_slice])
82+
Sigma = Diagonal(xnp.sqrt(eig_vals[eig_slice]))
83+
U = Unitary(lazify((A @ V @ inv(Sigma)).to_dense()))
84+
return U, Sigma, V
85+
86+
87+
@dispatch
88+
def svd(A: Identity, k: int, which: str, alg: Algorithm):
89+
ones = A.xnp.ones(A.shape[0], device=A.device, dtype=A.dtype)
90+
return Unitary(I_like(A)), Diagonal(ones), Unitary(I_like(A))
91+
92+
93+
@dispatch
94+
def svd(A: Diagonal, k: int, which: str, alg: Algorithm):
95+
return Unitary(I_like(A)), A, Unitary(I_like(A))

cola/linalg/tbd/svd.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

0 commit comments

Comments
 (0)