Skip to content

Commit 2ddcb3d

Browse files
authored
Code maintenance (#104)
Code maintenance: sorted imports, opened up CoLA ops for go to, pytest incorporates defaults marks, and took out unsued files.
1 parent c29c268 commit 2ddcb3d

32 files changed

+199
-458
lines changed

cola/annotations.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
1+
from collections.abc import Iterable
12
from functools import reduce
23
from typing import Set, Union
3-
from collections.abc import Iterable
4+
45
from plum import dispatch
5-
from cola.ops import LinearOperator, Array
6-
from cola.ops import Kronecker, Product, Sum
7-
from cola.ops import Transpose, Adjoint
8-
from cola.ops import BlockDiag, Identity, ScalarMul
9-
from cola.ops import Hessian, Permutation, Sliced
6+
7+
from cola.ops import (
8+
Adjoint,
9+
Array,
10+
BlockDiag,
11+
Hessian,
12+
Identity,
13+
Kronecker,
14+
LinearOperator,
15+
Permutation,
16+
Product,
17+
ScalarMul,
18+
Sliced,
19+
Sum,
20+
Transpose,
21+
)
1022
from cola.utils import export
1123

1224
Scalar = Array

cola/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" High level linear algebra functions, """
22
import pkgutil
3+
34
from cola.utils import import_from_all
45

56
__all__ = []

cola/linalg/algorithm_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from types import SimpleNamespace
2+
13
from plum import parametric
4+
25
from cola.ops import LinearOperator
36
from cola.utils import export
4-
from types import SimpleNamespace
7+
58
# import pytreeclass as tc
69

710

cola/linalg/decompositions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pkgutil
2+
23
from cola.utils import import_from_all
34

45
__all__ = []

cola/linalg/decompositions/arnoldi.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from typing import Tuple
2-
from cola import Stiefel
3-
from cola.ops import LinearOperator
4-
from cola.ops import Array, Dense
5-
from cola.ops import Householder, Product
2+
63
# from cola.utils import export
7-
from cola import lazify
4+
from cola import Stiefel, lazify
5+
from cola.ops import Array, Dense, Householder, LinearOperator, Product
86

97
# def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs):
108
# val_grads, eig_grads, _ = grads

cola/linalg/decompositions/lanczos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import cola
12
from cola import SelfAdjoint, Unitary
23
from cola.fns import lazify
3-
from cola.ops import Array, LinearOperator, Dense, Tridiagonal
4-
import cola
4+
from cola.ops import Array, Dense, LinearOperator, Tridiagonal
55

66

77
def lanczos_eig_bwd(res, grads, unflatten, *args, **kwargs):

cola/linalg/eig/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pkgutil
2+
23
from cola.utils import import_from_all
34

45
__all__ = []

cola/linalg/eig/iram.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import numpy as np
22
from scipy.sparse.linalg import LinearOperator as LO
33
from scipy.sparse.linalg import eigs
4-
from cola.ops import LinearOperator
5-
from cola.ops import Array
6-
from cola.ops import Dense
4+
5+
from cola.ops import Array, Dense, LinearOperator
76
from cola.utils import export
87
from cola.utils.utils_linalg import get_numpy_dtype
98

cola/linalg/eig/lobpcg.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import numpy as np
21
from dataclasses import dataclass
3-
from cola.linalg.algorithm_base import Algorithm
2+
3+
import numpy as np
44
from scipy.sparse.linalg import LinearOperator as LO
55
from scipy.sparse.linalg import lobpcg as lobpcg_sp
6-
from cola.ops import LinearOperator
7-
from cola.ops import Dense
6+
7+
from cola.linalg.algorithm_base import Algorithm
8+
from cola.ops import Dense, LinearOperator
89
from cola.utils import export
910

1011

cola/linalg/eig/power_iteration.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from cola.utils import export
2-
from cola.ops import LinearOperator
3-
from cola.linalg.algorithm_base import Algorithm
41
from dataclasses import dataclass
5-
from typing import Optional, Any
2+
from typing import Any, Optional
3+
4+
from cola.linalg.algorithm_base import Algorithm
5+
from cola.ops import LinearOperator
6+
from cola.utils import export
67

78
PRNGKey = Any
89

cola/linalg/inverse/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pkgutil
2+
23
from cola.utils import import_from_all
34

45
__all__ = []

cola/linalg/inverse/pinv.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import numpy as np
22
from plum import dispatch
33

4-
import cola
4+
from cola.annotations import PSD
55
from cola.linalg.algorithm_base import Algorithm, Auto, IterativeOperatorWInfo
66
from cola.linalg.inverse.cg import CG
7-
from cola.ops.operators import Diagonal, Identity, LinearOperator, Permutation, ScalarMul, I_like
8-
from cola.annotations import PSD
7+
from cola.ops.operators import Diagonal, I_like, Identity, LinearOperator, Permutation, ScalarMul
98
from cola.utils import export
109
from cola.utils.utils_linalg import get_precision
1110

@@ -43,7 +42,7 @@ def pinv(A: LinearOperator, alg: Algorithm = Auto()):
4342
4443
Example:
4544
>>> A = MyLinearOperator()
46-
>>> x = cola.pseudo(A) @ b
45+
>>> x = cola.pinv(A) @ b
4746
4847
"""
4948

@@ -67,7 +66,7 @@ def pinv(A: LinearOperator, alg: Auto):
6766
def pinv(A: LinearOperator, alg: CG):
6867
xnp = A.xnp
6968
M = A.H @ A
70-
cons = get_precision(xnp, A.dtype) * xnp.sqrt(cola.eigmax(M))
69+
cons = get_precision(xnp, A.dtype) * max(A.shape)
7170
Op = IterativeOperatorWInfo(M, alg)
7271
return PSD(Op + cons * I_like(M)) @ A.H
7372

cola/linalg/logdet/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pkgutil
2+
23
from cola.utils import import_from_all
34

45
__all__ = []

cola/linalg/logdet/logdet.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
import numpy as np
21
from functools import reduce
2+
3+
import numpy as np
34
from plum import dispatch
5+
46
from cola.annotations import PSD
5-
from cola.ops.operators import LinearOperator, Triangular, Permutation, Identity, ScalarMul
6-
from cola.ops.operators import Diagonal, Kronecker, BlockDiag, Product
7-
from cola.utils import export
87
from cola.linalg.algorithm_base import Algorithm, Auto
9-
from cola.linalg.decompositions.decompositions import Cholesky, LU, Arnoldi, Lanczos
10-
from cola.linalg.decompositions.decompositions import plu, cholesky
8+
from cola.linalg.decompositions.decompositions import LU, Arnoldi, Cholesky, Lanczos, cholesky, plu
119
from cola.linalg.trace.diag_trace import trace
1210
from cola.linalg.unary.unary import log
11+
from cola.ops.operators import (
12+
BlockDiag,
13+
Diagonal,
14+
Identity,
15+
Kronecker,
16+
LinearOperator,
17+
Permutation,
18+
Product,
19+
ScalarMul,
20+
Triangular,
21+
)
22+
from cola.utils import export
1323

1424

1525
def product(xs):

cola/linalg/preconditioning/preconditioners.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import Union
2-
from cola.ops import LinearOperator
3-
from cola.linalg.eig.power_iteration import power_iteration
2+
43
from plum import dispatch
4+
5+
from cola.linalg.eig.power_iteration import power_iteration
6+
from cola.ops import LinearOperator
57
from cola.utils import export
68

79

cola/linalg/tbd/nullspace.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from cola.ops import LinearOperator, Array
2-
from cola.backends import get_library_fns
3-
from cola.utils import export
41
import logging
2+
53
import numpy as np
64
from plum import dispatch
75

6+
from cola.backends import get_library_fns
7+
from cola.ops import Array, LinearOperator
8+
from cola.utils import export
9+
810
eigmax = None # TODO: fix
911

1012

cola/linalg/tbd/pinv.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

cola/linalg/tbd/slq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Callable
2-
from cola.ops import LinearOperator
2+
33
from cola.linalg.decompositions.lanczos import lanczos
44
from cola.linalg.inverse.cg import cg
5+
from cola.ops import LinearOperator
56
from cola.utils import export
67
from cola.utils.custom_autodiff import iterative_autograd
78

cola/linalg/tbd/svrg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
2+
23
import cola
4+
35
# from cola.linalg.eigs import eigmax
4-
from cola.ops import Sum, Product, Dense
5-
from cola.ops import I_like
6+
from cola.ops import Dense, I_like, Product, Sum
67
from cola.utils import export
8+
79
# import standard Union type
810

911

cola/linalg/trace/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pkgutil
2+
23
from cola.utils import import_from_all
34

45
__all__ = []

cola/linalg/trace/diag_trace.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
from functools import reduce
2-
from cola.utils import export, dispatch
3-
from cola.ops.operators import LinearOperator, I_like, Diagonal, Identity
4-
from cola.ops.operators import BlockDiag, ScalarMul, Sum, Dense
5-
from cola.ops.operators import Kronecker, KronSum
6-
from cola.linalg.algorithm_base import Algorithm, Auto
7-
from cola.linalg.trace.diagonal_estimation import Hutch, HutchPP, Exact
2+
83
import numpy as np
94

5+
from cola.linalg.algorithm_base import Algorithm, Auto
6+
from cola.linalg.trace.diagonal_estimation import Exact, Hutch, HutchPP
7+
from cola.ops.operators import (
8+
BlockDiag,
9+
Dense,
10+
Diagonal,
11+
I_like,
12+
Identity,
13+
Kronecker,
14+
KronSum,
15+
LinearOperator,
16+
ScalarMul,
17+
Sum,
18+
)
19+
from cola.utils import dispatch, export
20+
1021

1122
@export
1223
@dispatch.abstract

cola/linalg/trace/diagonal_estimation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Optional
3+
14
import numpy as np
2-
from cola.utils import export
3-
from cola.ops import I_like, LinearOperator
5+
46
from cola.linalg.algorithm_base import Algorithm
5-
from dataclasses import dataclass
6-
from typing import Optional, Any
7+
from cola.ops import I_like, LinearOperator
8+
from cola.utils import export
79

810
PRNGKey = Any
911

cola/linalg/unary/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pkgutil
2+
23
from cola.utils import import_from_all
34

45
__all__ = []

cola/linalg/unary/unary.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
1-
from plum import dispatch
21
from dataclasses import dataclass
2+
from functools import reduce
33
from numbers import Number
44
from typing import Callable
5-
from functools import reduce
5+
66
import numpy as np
7-
from plum import parametric
7+
from plum import dispatch, parametric
8+
9+
from cola.annotations import PSD, SelfAdjoint
810
from cola.fns import lazify
9-
from cola.ops import LinearOperator
10-
from cola.ops import Diagonal, Identity, ScalarMul
11-
from cola.ops import BlockDiag, Kronecker, KronSum, I_like, Transpose, Adjoint
12-
from cola.annotations import SelfAdjoint, PSD
1311
from cola.linalg.algorithm_base import Algorithm, Auto
14-
from cola.linalg.inverse.inv import inv
12+
from cola.linalg.decompositions.arnoldi import arnoldi
13+
from cola.linalg.decompositions.decompositions import LU, Arnoldi, Cholesky, Lanczos
14+
from cola.linalg.decompositions.lanczos import lanczos
1515
from cola.linalg.inverse.cg import CG
1616
from cola.linalg.inverse.gmres import GMRES
17-
from cola.linalg.decompositions.lanczos import lanczos
18-
from cola.linalg.decompositions.arnoldi import arnoldi
19-
from cola.linalg.decompositions.decompositions import Arnoldi, Lanczos, LU, Cholesky
17+
from cola.linalg.inverse.inv import inv
18+
from cola.ops import (
19+
Adjoint,
20+
BlockDiag,
21+
Diagonal,
22+
I_like,
23+
Identity,
24+
Kronecker,
25+
KronSum,
26+
LinearOperator,
27+
ScalarMul,
28+
Transpose,
29+
)
2030
from cola.utils import export
2131

2232

0 commit comments

Comments
 (0)