@@ -187,7 +187,7 @@ def solve(self, rhs):
187
187
else :
188
188
if ndim == 2 and rhs .shape [- 1 ] == 1 :
189
189
warnings .warn (
190
- "In the pymatsolver v0.7.0 passing a vector of shape (n, 1) to the solve method "
190
+ "In Future pymatsolver v0.4.0, passing a vector of shape (n, 1) to the solve method "
191
191
"will return an array with shape (n, 1), instead of always returning a flattened array. "
192
192
"This is to be consistent with numpy.linalg.solve broadcasting." ,
193
193
FutureWarning
@@ -199,7 +199,7 @@ def solve(self, rhs):
199
199
# switch last two dimensions
200
200
rhs = np .transpose (rhs , (* range (rhs .ndim - 2 ), - 1 , - 2 ))
201
201
in_shape = rhs .shape
202
- # Then collapse all other vectors into the last dimension
202
+ # Then collapse all other vectors into the first dimension
203
203
rhs = np .reshape (rhs , (- 1 , in_shape [- 1 ]))
204
204
# Then reverse the two axes to get the array to end up in fortran order
205
205
# (which is more common for direct solvers).
@@ -226,7 +226,7 @@ def solve(self, rhs):
226
226
def _solve_single (self , rhs ):
227
227
...
228
228
229
-
229
+ @ abstractmethod
230
230
def _solve_multiple (self , rhs ):
231
231
...
232
232
@@ -263,7 +263,7 @@ def get_attributes(self):
263
263
264
264
class Diagonal (Base ):
265
265
266
- def __init__ (self , A , ** kwargs ):
266
+ def __init__ (self , A , check_accuracy = False , check_rtol = 1e-6 , check_atol = 0 , accuracy_tol = None , ** kwargs ):
267
267
try :
268
268
self ._diagonal = np .asarray (A .diagonal ())
269
269
if not np .all (self ._diagonal ):
@@ -275,7 +275,7 @@ def __init__(self, A, **kwargs):
275
275
kwargs .pop ("is_hermitian" , None )
276
276
is_positive_definite = kwargs .pop ("is_positive_definite" , None )
277
277
super ().__init__ (
278
- A , is_symmetric = True , is_hermitian = True , ** kwargs
278
+ A , is_symmetric = True , is_hermitian = True , check_accuracy = check_accuracy , check_rtol = check_rtol , check_atol = check_atol , accuracy_tol = accuracy_tol , ** kwargs
279
279
)
280
280
if is_positive_definite is None :
281
281
if self .is_real :
@@ -294,13 +294,13 @@ def _solve_multiple(self, rhs):
294
294
295
295
296
296
class TriangularSolver (Base ):
297
- def __init__ (self , A , lower = True , ** kwargs ):
297
+ def __init__ (self , A , lower = True , check_accuracy = False , check_rtol = 1e-6 , check_atol = 0 , accuracy_tol = None , ** kwargs ):
298
298
kwargs .pop ("is_hermitian" , False )
299
299
kwargs .pop ("is_symmetric" , False )
300
300
if not (sp .issparse (A ) and A .format in ['csr' ,'csc' ]):
301
301
A = sp .csc_matrix (A )
302
302
A .sum_duplicates ()
303
- super ().__init__ (A , is_hermitian = False , is_symmetric = False , ** kwargs )
303
+ super ().__init__ (A , is_hermitian = False , is_symmetric = False , check_accuracy = check_accuracy , check_rtol = check_rtol , check_atol = check_atol , accuracy_tol = accuracy_tol , ** kwargs )
304
304
self .lower = lower
305
305
306
306
@property
@@ -320,23 +320,23 @@ def _solve_multiple(self, rhs):
320
320
_solve_single = _solve_multiple
321
321
322
322
def transpose (self ):
323
- transed = super ().transpose ()
324
- transed .lower = not self .lower
325
- return transed
323
+ trans = super ().transpose ()
324
+ trans .lower = not self .lower
325
+ return trans
326
326
327
327
class Forward (TriangularSolver ):
328
328
329
- def __init__ (self , A , ** kwargs ):
329
+ def __init__ (self , A , check_accuracy = False , check_rtol = 1e-6 , check_atol = 0 , accuracy_tol = None , ** kwargs ):
330
330
kwargs .pop ("lower" , None )
331
- super ().__init__ (A , lower = True , ** kwargs )
331
+ super ().__init__ (A , lower = True , check_accuracy = check_accuracy , check_rtol = check_rtol , check_atol = check_atol , accuracy_tol = accuracy_tol , ** kwargs )
332
332
333
333
class Backward (TriangularSolver ):
334
334
335
335
_transpose_class = Forward
336
336
337
- def __init__ (self , A , ** kwargs ):
337
+ def __init__ (self , A , check_accuracy = False , check_rtol = 1e-6 , check_atol = 0 , accuracy_tol = None , ** kwargs ):
338
338
kwargs .pop ("lower" , None )
339
- super ().__init__ (A , lower = False , ** kwargs )
339
+ super ().__init__ (A , lower = False , check_accuracy = check_accuracy , check_rtol = check_rtol , check_atol = check_atol , accuracy_tol = accuracy_tol , ** kwargs )
340
340
341
341
342
342
Forward ._transpose_class = Backward
0 commit comments