Skip to content

Commit 7f81ada

Browse files
authored
Merge pull request #50 from simpeg/conjugate
Add conjugate solve
2 parents 32ade92 + 7d1dc30 commit 7f81ada

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

pymatsolver/solvers.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class Base(ABC):
4747
__numpy_ufunc__ = True
4848
__array_ufunc__ = None
4949

50+
_is_conjugate = False
51+
5052
def __init__(
5153
self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs
5254
):
@@ -251,7 +253,13 @@ def _transpose_class(self):
251253
return self.__class__
252254

253255
def transpose(self):
254-
"""Return the transposed solve operator."""
256+
"""Return the transposed solve operator.
257+
258+
Returns
259+
-------
260+
pymatsolver.solvers.Base
261+
"""
262+
255263
if self.is_symmetric:
256264
return self
257265
if self._transpose_class is None:
@@ -274,6 +282,23 @@ def T(self):
274282
"""
275283
return self.transpose()
276284

285+
def conjugate(self):
286+
"""Return the complex conjugate version of this solver.
287+
288+
Returns
289+
-------
290+
pymatsolver.solvers.Base
291+
"""
292+
if self.is_real:
293+
return self
294+
else:
295+
# make a shallow copy of myself
296+
conjugated = copy.copy(self)
297+
conjugated._is_conjugate = not self._is_conjugate
298+
return conjugated
299+
300+
conj = conjugate
301+
277302
def _compute_accuracy(self, rhs, x):
278303
resid_norm = np.linalg.norm(rhs - self.A @ x)
279304
rhs_norm = np.linalg.norm(rhs)
@@ -308,6 +333,8 @@ def solve(self, rhs):
308333
if ndim == 1:
309334
if len(rhs) != n:
310335
raise ValueError(f'Expected a vector of length {n}, got {len(rhs)}')
336+
if self._is_conjugate:
337+
rhs = rhs.conjugate()
311338
x = self._solve_single(rhs)
312339
else:
313340
if ndim == 2 and rhs.shape[-1] == 1:
@@ -331,6 +358,8 @@ def solve(self, rhs):
331358
# (which is more common for direct solvers).
332359
rhs = rhs.transpose()
333360
# should end up with shape (n, -1)
361+
if self._is_conjugate:
362+
rhs = rhs.conjugate()
334363
x = self._solve_multiple(rhs)
335364
if do_broadcast:
336365
# undo the reshaping above
@@ -347,6 +376,9 @@ def solve(self, rhs):
347376
#TODO remove this in v0.4.0.
348377
if x.size == n:
349378
x = x.reshape(-1)
379+
380+
if self._is_conjugate:
381+
x = x.conjugate()
350382
return x
351383

352384
@abstractmethod

tests/test_conjugate.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
import pymatsolver
3+
import numpy as np
4+
import scipy.sparse as sp
5+
import numpy.testing as npt
6+
7+
8+
@pytest.mark.parametrize('solver_class', [pymatsolver.Solver, pymatsolver.SolverLU, pymatsolver.Pardiso, pymatsolver.Mumps])
9+
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
10+
@pytest.mark.parametrize('n_rhs', [1, 4])
11+
def test_conjugate_solve(solver_class, dtype, n_rhs):
12+
if solver_class is pymatsolver.Pardiso and not pymatsolver.AvailableSolvers['Pardiso']:
13+
pytest.skip("pydiso not installed.")
14+
if solver_class is pymatsolver.Mumps and not pymatsolver.AvailableSolvers['Mumps']:
15+
pytest.skip("python-mumps not installed.")
16+
17+
n = 10
18+
D = sp.diags(np.linspace(1, 10, n))
19+
if dtype == np.float64:
20+
L = sp.diags([1, -1], [0, -1], shape=(n, n))
21+
22+
sol = np.linspace(0.9, 1.1, n)
23+
# non-symmetric real matrix
24+
else:
25+
# non-symmetric
26+
L = sp.diags([1, -1j], [0, -1], shape=(n, n))
27+
sol = np.linspace(0.9, 1.1, n) - 1j * np.linspace(0.9, 1.1, n)[::-1]
28+
29+
if n_rhs > 1:
30+
sol = np.pad(sol[:, None], [(0, 0), (0, n_rhs - 1)], mode='constant')
31+
32+
A = D @ L @ D @ L.T
33+
34+
# double check it solves
35+
rhs = A @ sol
36+
Ainv = solver_class(A)
37+
npt.assert_allclose(Ainv @ rhs, sol)
38+
39+
# is conjugate solve correct?
40+
rhs_conj = A.conjugate() @ sol
41+
Ainv_conj = Ainv.conjugate()
42+
npt.assert_allclose(Ainv_conj @ rhs_conj, sol)
43+
44+
# is conjugate -> conjugate solve correct?
45+
Ainv2 = Ainv_conj.conjugate()
46+
npt.assert_allclose(Ainv2 @ rhs, sol)

0 commit comments

Comments
 (0)