Skip to content

Commit 58eb8d1

Browse files
committed
add base kwargs to subclasses
1 parent 76c9c34 commit 58eb8d1

File tree

5 files changed

+37
-35
lines changed

5 files changed

+37
-35
lines changed

pymatsolver/direct/mumps.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ class Mumps(Base):
77
"""
88
_transposed = False
99

10-
def __init__(self, A, ordering=None, **kwargs):
11-
super().__init__(A, **kwargs)
10+
def __init__(self, A, ordering=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
11+
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
1212
if ordering is None:
1313
ordering = "metis"
1414
self.ordering = ordering
@@ -19,7 +19,6 @@ def _set_A(self, A):
1919
self.solver.set_matrix(
2020
A,
2121
symmetric=self.is_symmetric,
22-
# positive_definite=self.is_positive_definite # doesn't (yet) support setting positive definiteness
2322
)
2423

2524
@property

pymatsolver/direct/pardiso.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class Pardiso(Base):
1616

1717
_transposed = False
1818

19-
def __init__(self, A, n_threads=None, **kwargs):
20-
super().__init__(A, **kwargs)
19+
def __init__(self, A, n_threads=None, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
20+
super().__init__(A, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, is_hermitian=is_hermitian, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
2121
self.solver = MKLPardisoSolver(
2222
self.A,
2323
matrix_type=self._matrixType(),

pymatsolver/iterative.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
class BiCGJacobi(Base):
1818
"""Bicg Solver with Jacobi preconditioner"""
1919

20-
def __init__(self, A, symmetric=None, maxiter=1000, rtol=1E-6, atol=0.0, **kwargs):
20+
def __init__(self, A, symmetric=None, maxiter=1000, rtol=1E-6, atol=0.0, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
2121
if symmetric is not None:
2222
warnings.warn(
23-
"The symmetric keyword argument is being deprecated and will be removed in pymatsolver 0.7.0",
23+
"The symmetric keyword argument is unused and is deprecated. It will be removed in pymatsolver 0.7.0.",
2424
DeprecationWarning, stacklevel=2
2525
)
26-
super().__init__(A, **kwargs)
26+
super().__init__(A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
2727
self._factored = False
2828
self.maxiter = maxiter
2929
self.rtol = rtol

pymatsolver/solvers.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def solve(self, rhs):
187187
else:
188188
if ndim == 2 and rhs.shape[-1] == 1:
189189
warnings.warn(
190-
"In the pymatsolver v0.7.0 passing a vector of shape (n, 1) to the solve method "
190+
"In Future pymatsolver v0.4.0, passing a vector of shape (n, 1) to the solve method "
191191
"will return an array with shape (n, 1), instead of always returning a flattened array. "
192192
"This is to be consistent with numpy.linalg.solve broadcasting.",
193193
FutureWarning
@@ -199,7 +199,7 @@ def solve(self, rhs):
199199
# switch last two dimensions
200200
rhs = np.transpose(rhs, (*range(rhs.ndim-2), -1, -2))
201201
in_shape = rhs.shape
202-
# Then collapse all other vectors into the last dimension
202+
# Then collapse all other vectors into the first dimension
203203
rhs = np.reshape(rhs, (-1, in_shape[-1]))
204204
# Then reverse the two axes to get the array to end up in fortran order
205205
# (which is more common for direct solvers).
@@ -226,7 +226,7 @@ def solve(self, rhs):
226226
def _solve_single(self, rhs):
227227
...
228228

229-
229+
@abstractmethod
230230
def _solve_multiple(self, rhs):
231231
...
232232

@@ -263,7 +263,7 @@ def get_attributes(self):
263263

264264
class Diagonal(Base):
265265

266-
def __init__(self, A, **kwargs):
266+
def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
267267
try:
268268
self._diagonal = np.asarray(A.diagonal())
269269
if not np.all(self._diagonal):
@@ -275,7 +275,7 @@ def __init__(self, A, **kwargs):
275275
kwargs.pop("is_hermitian", None)
276276
is_positive_definite = kwargs.pop("is_positive_definite", None)
277277
super().__init__(
278-
A, is_symmetric=True, is_hermitian=True, **kwargs
278+
A, is_symmetric=True, is_hermitian=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs
279279
)
280280
if is_positive_definite is None:
281281
if self.is_real:
@@ -294,13 +294,13 @@ def _solve_multiple(self, rhs):
294294

295295

296296
class TriangularSolver(Base):
297-
def __init__(self, A, lower=True, **kwargs):
297+
def __init__(self, A, lower=True, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
298298
kwargs.pop("is_hermitian", False)
299299
kwargs.pop("is_symmetric", False)
300300
if not (sp.issparse(A) and A.format in ['csr','csc']):
301301
A = sp.csc_matrix(A)
302302
A.sum_duplicates()
303-
super().__init__(A, is_hermitian=False, is_symmetric=False, **kwargs)
303+
super().__init__(A, is_hermitian=False, is_symmetric=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
304304
self.lower = lower
305305

306306
@property
@@ -320,23 +320,23 @@ def _solve_multiple(self, rhs):
320320
_solve_single = _solve_multiple
321321

322322
def transpose(self):
323-
transed = super().transpose()
324-
transed.lower = not self.lower
325-
return transed
323+
trans = super().transpose()
324+
trans.lower = not self.lower
325+
return trans
326326

327327
class Forward(TriangularSolver):
328328

329-
def __init__(self, A, **kwargs):
329+
def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
330330
kwargs.pop("lower", None)
331-
super().__init__(A, lower=True, **kwargs)
331+
super().__init__(A, lower=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
332332

333333
class Backward(TriangularSolver):
334334

335335
_transpose_class = Forward
336336

337-
def __init__(self, A, **kwargs):
337+
def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
338338
kwargs.pop("lower", None)
339-
super().__init__(A, lower=False, **kwargs)
339+
super().__init__(A, lower=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
340340

341341

342342
Forward._transpose_class = Backward

pymatsolver/wrappers.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,11 @@ def WrapDirect(fun, factorize=True, name=None):
7171
>>> SolverLU = pymatsolver.WrapDirect(splu, factorize=True)
7272
"""
7373

74-
def __init__(self, A, **kwargs):
75-
pymatsolver_options = {}
76-
if (check_accuracy := kwargs.pop('check_accuracy', None)) is not None:
77-
pymatsolver_options['check_accuracy'] = check_accuracy
78-
if (accuracy_tol := kwargs.pop('accuracy_tol', None)) is not None:
79-
pymatsolver_options['accuracy_tol'] = accuracy_tol
80-
Base.__init__(self, A, **pymatsolver_options)
74+
def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, accuracy_tol=None, **kwargs):
75+
Base.__init__(
76+
self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol,
77+
is_symmetric=False, is_hermitian=False
78+
)
8179
self.kwargs = kwargs
8280
if factorize:
8381
self.solver = fun(self.A, **self.kwargs)
@@ -127,7 +125,7 @@ def clean(self):
127125
)
128126

129127

130-
def WrapIterative(fun, check_accuracy=False, accuracy_tol=1e-6, name=None):
128+
def WrapIterative(fun, check_accuracy=None, accuracy_tol=None, name=None):
131129
"""
132130
Wraps an iterative Solver.
133131
@@ -148,11 +146,16 @@ def WrapIterative(fun, check_accuracy=False, accuracy_tol=1e-6, name=None):
148146
>>> SolverCG = pymatsolver.WrapIterative(cg)
149147
150148
"""
151-
152-
def __init__(self, A, **kwargs):
153-
check_acc = kwargs.pop('check_accuracy', check_accuracy)
154-
acc_tol = kwargs.pop('accuracy_tol', accuracy_tol)
155-
Base.__init__(self, A, check_accuracy=check_acc, accuracy_tol=acc_tol)
149+
if check_accuracy is not None or accuracy_tol is not None:
150+
warnings.warn('check_accuracy and accuracy_tol were unused and are now deprecated. They '
151+
'will be removed in pymatsolver v0.4.0. Please pass the keyword arguments `check_rtol` '
152+
'and check_atol directly to the wrapped solver class.', FutureWarning)
153+
154+
def __init__(self, A, check_accuracy=False, check_rtol=1E-6, check_atol=0, accuracy_tol=None, **kwargs):
155+
Base.__init__(
156+
self, A, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol,
157+
is_symmetric=False, is_hermitian=False
158+
)
156159
self.kwargs = kwargs
157160

158161
@property

0 commit comments

Comments
 (0)