Skip to content

Commit dc4905b

Browse files
authored
Arnoldi & Lanczos output and naming convention parity (#46)
2 parents 76a799a + 15c64d6 commit dc4905b

File tree

6 files changed

+39
-24
lines changed

6 files changed

+39
-24
lines changed

cola/algorithms/arnoldi.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import cola
88
from cola import Stiefel, lazify
99

10-
1110
# def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs):
1211
# val_grads, eig_grads, _ = grads
1312
# op_args, (eig_vals, eig_vecs, _) = res
@@ -34,8 +33,8 @@
3433
# return (dA, )
3534

3635

37-
#@export
38-
#@iterative_autograd(arnoldi_eigs_bwd)
36+
# @export
37+
# @iterative_autograd(arnoldi_eigs_bwd)
3938
@export
4039
def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
4140
tol: float = 1e-7, use_householder: bool = False, pbar: bool = False):

cola/algorithms/lanczos.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def lanczos_max_eig(A: LinearOperator, rhs: Array, max_iters: int, tol: float =
1717
max_iters: int maximum number of iters to run lanczos
1818
tol: float: tolerance criteria to stop lanczos
1919
"""
20-
eigvals, *_ = lanczos(A=A, start_vector=rhs, max_iters=max_iters, tol=tol)
20+
eigvals, *_ = lanczos_eigs(A=A, start_vector=rhs, max_iters=max_iters, tol=tol)
2121
return eigvals[-1]
2222

2323

@@ -50,8 +50,8 @@ def altogether(*theta):
5050
# @export
5151
# @iterative_autograd(lanczos_eig_bwd)
5252
@export
53-
def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, tol: float = 1e-7,
54-
pbar: bool = False):
53+
def lanczos_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
54+
tol: float = 1e-7, pbar: bool = False):
5555
"""
5656
Computes the eigenvalues and eigenvectors using Lanczos.
5757
@@ -71,12 +71,11 @@ def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
7171
7272
"""
7373
xnp = A.xnp
74-
Q, T, info = lanczos_decomp(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol,
75-
pbar=pbar)
74+
Q, T, info = lanczos(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, pbar=pbar)
7675
eigvals, eigvectors = xnp.eigh(T)
7776
idx = xnp.argsort(eigvals, axis=-1)
78-
V = lazify(Q) @ lazify(eigvectors[:,idx])
79-
77+
V = lazify(Q) @ lazify(eigvectors[:, idx])
78+
8079
eigvals = eigvals[..., idx]
8180
# V = V[..., idx]
8281
return eigvals, V, info
@@ -85,16 +84,14 @@ def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
8584
def LanczosDecomposition(A: LinearOperator, start_vector=None, max_iters=100, tol=1e-7, pbar=False):
8685
""" Provides the Lanczos decomposition of a matrix A = Q T Q^*.
8786
LinearOperator form of lanczos, see lanczos for arguments."""
88-
Q, T, info = lanczos_decomp(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol,
89-
pbar=pbar)
87+
Q, T, info = lanczos(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, pbar=pbar)
9088
A_approx = cola.UnitaryDecomposition(lazify(Q), SelfAdjoint(lazify(T)))
9189
A_approx.info = info
9290
return A_approx
9391

9492

9593
@export
96-
def lanczos_decomp(A: LinearOperator, start_vector: Array = None, max_iters=100, tol=1e-7,
97-
pbar=False):
94+
def lanczos(A: LinearOperator, start_vector: Array = None, max_iters=100, tol=1e-7, pbar=False):
9895
"""
9996
Computes the Lanczos decomposition of a the operator A, A = Q T Q^*.
10097

cola/algorithms/slq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def slq_fwd(A, fun, num_samples, max_iters, tol, pbar, key):
5050
tau = Q[..., 0, :]
5151
# approx = xnp.sum(tau**2 * fun(eigvals), axis=-1)
5252
# fn_vals = xnp.where(xnp.abs(eigvals) > _mp, fun(eigvals), xnp.zeros_like(eigvals))
53-
const = 10*_mp * xnp.max(eigvals, axis=1, keepdims=True)
53+
const = 10 * _mp * xnp.max(eigvals, axis=1, keepdims=True)
5454
fn_vals = xnp.where(xnp.abs(eigvals) > const, fun(eigvals), xnp.zeros_like(eigvals))
5555
approx = xnp.sum(tau**2 * fn_vals, axis=-1)
5656
estimate = A.shape[-2] * approx

cola/linalg/eigs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from cola.ops import Identity
1212
from cola.ops import Triangular
1313
from cola.algorithms import power_iteration
14-
from cola.algorithms.lanczos import lanczos
14+
from cola.algorithms.lanczos import lanczos_eigs
1515
from cola.algorithms.arnoldi import arnoldi_eigs
1616
from cola.utils import export
1717

@@ -75,7 +75,7 @@ def eig(A: LinearOperator, **kwargs):
7575
eig_vals, eig_vecs = xnp.eigh(A.to_dense())
7676
return eig_vals[eig_slice], Stiefel(lazify(eig_vecs[:, eig_slice]))
7777
elif method in ('lanczos', 'iterative') or (method == 'auto' and prod(A.shape) >= 1e6):
78-
eig_vals, eig_vecs, _ = lanczos(A, **kws)
78+
eig_vals, eig_vecs, _ = lanczos_eigs(A, **kws)
7979
return eig_vals, eig_vecs
8080
else:
8181
raise ValueError(f"Unknown method {method} for SelfAdjoint operator")

tests/algorithms/test_arnoldi.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from cola.ops import Householder
33
from cola.ops import Product
44
from cola.ops import Dense
5-
from cola.algorithms.arnoldi import get_householder_vec
65
from cola.fns import lazify
76
from cola.algorithms.arnoldi import get_arnoldi_matrix
87
from cola.algorithms.arnoldi import arnoldi_eigs
@@ -92,7 +91,7 @@ def test_householder_arnoldi_decomp(backend):
9291
# A_np, rhs_np = np.array(A, dtype=np.complex128), np.array(rhs[:, 0], dtype=np.complex128)
9392
A_np, rhs_np = np.array(A, dtype=np.float64), np.array(rhs[:, 0], dtype=np.float64)
9493
# Q_sol, H_sol = run_householder_arnoldi(A, rhs, A.shape[0], np.float64, xnp)
95-
Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64, xnp)
94+
Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64)
9695

9796
# fn = run_householder_arnoldi
9897
fn = xnp.jit(run_householder_arnoldi, static_argnums=(0, 2))
@@ -146,7 +145,7 @@ def test_numpy_arnoldi(backend):
146145
rhs = np.random.normal(size=(A.shape[0], ))
147146
# rhs = np.random.normal(size=(A.shape[0], 2)).view(np.complex128)[:, 0]
148147

149-
Q, H = run_householder_arnoldi_np(A, rhs, max_iter=A.shape[0], dtype=dtype, xnp=xnp)
148+
Q, H = run_householder_arnoldi_np(A, rhs, max_iter=A.shape[0], dtype=dtype)
150149
abs_error = np.linalg.norm(np.eye(A.shape[0]) - Q.T @ Q)
151150
assert abs_error < 1e-4
152151
abs_error = np.linalg.norm(Q.T @ A @ Q - H)
@@ -159,10 +158,10 @@ def test_numpy_arnoldi(backend):
159158
assert abs_error < 1e-10
160159

161160

162-
def run_householder_arnoldi_np(A, rhs, max_iter, dtype, xnp):
161+
def run_householder_arnoldi_np(A, rhs, max_iter, dtype):
163162
H, Q, Ps, zj = initialize_householder_arnoldi(rhs, max_iter, dtype)
164163
for jdx in range(1, max_iter + 2):
165-
vec, beta = get_householder_vec(zj, jdx - 1, xnp)
164+
vec, beta = get_householder_vec_np(zj, jdx - 1)
166165
Ps[jdx].vec, Ps[jdx].beta = vec[:, None], beta
167166
H[:, jdx - 1] = np.array(Ps[jdx] @ zj)
168167
if jdx <= max_iter:
@@ -186,6 +185,26 @@ def initialize_householder_arnoldi(rhs, max_iter, dtype):
186185
return H, Q, Ps, zj
187186

188187

188+
def get_householder_vec_np(x, idx):
189+
sigma_2 = np.linalg.norm(x[idx + 1:])**2.
190+
vec = np.zeros_like(x)
191+
vec[idx:] = x[idx:]
192+
if sigma_2 == 0 and x[idx] >= 0:
193+
beta = 0
194+
elif sigma_2 == 0 and x[idx] < 0:
195+
beta = -2
196+
else:
197+
x_norm_partial = np.sqrt(x[idx]**2 + sigma_2)
198+
if x[idx] <= 0:
199+
vec[idx] = x[idx] - x_norm_partial
200+
else:
201+
vec[idx] = -sigma_2 / (x[idx] + x_norm_partial)
202+
beta = 2 * vec[idx]**2 / (sigma_2 + vec[idx]**2)
203+
vec = vec / vec[idx]
204+
vec[idx:] = vec[idx:] / vec[idx]
205+
return vec, beta
206+
207+
189208
def run_arnoldi(A, rhs, max_iter, tol, dtype):
190209
Q, H = initialize_arnoldi(rhs, max_iter=max_iter, dtype=dtype)
191210
idx, vec = 0, rhs.copy()

tests/algorithms/test_lanczos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from cola.algorithms.lanczos import construct_tridiagonal_batched
66
from cola.algorithms.lanczos import get_lanczos_coeffs
77
from cola.algorithms.lanczos import lanczos_parts
8-
from cola.algorithms.lanczos import lanczos
8+
from cola.algorithms.lanczos import lanczos_eigs
99
from cola.algorithms.lanczos import lanczos_max_eig
1010
from cola.utils_test import get_xnp, parametrize, relative_error
1111
from cola.utils_test import generate_spectrum, generate_pd_from_diag
@@ -35,7 +35,7 @@ def test_lanczos_vjp(backend):
3535

3636
def f(theta):
3737
Aop = unflatten([theta])
38-
out = lanczos(Aop, x0, max_iters=10, tol=1e-6, pbar=False)
38+
out = lanczos_eigs(Aop, x0, max_iters=10, tol=1e-6, pbar=False)
3939
eig_vals, eig_vecs, _ = out
4040
# loss = xnp.sum(eig_vals ** 2.) + xnp.sum(xnp.abs(eig_vecs), axis=[0, 1])
4141
loss = xnp.sum(eig_vals**2.)

0 commit comments

Comments
 (0)