Skip to content

Commit 169c39f

Browse files
committed
updates
Signed-off-by: Christian Glusa <caglusa@sandia.gov>
1 parent 11cd6a9 commit 169c39f

21 files changed

+747
-549
lines changed

base/PyNucleus_base/LinearOperator_decl_{SCALAR}.pxi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,12 @@ cdef class {SCALAR_label}Transpose_Linear_Operator({SCALAR_label}LinearOperator)
132132
{SCALAR}_t[::1] rhs,
133133
{SCALAR}_t[::1] result,
134134
BOOL_t simpleResidual=*)
135+
136+
137+
cdef class {SCALAR_label}nullOperator({SCALAR_label}LinearOperator):
138+
cdef INDEX_t matvec(self,
139+
{SCALAR}_t[::1] x,
140+
{SCALAR}_t[::1] y) except -1
141+
cdef INDEX_t matvec_no_overwrite(self,
142+
{SCALAR}_t[::1] x,
143+
{SCALAR}_t[::1] y) except -1

base/PyNucleus_base/LinearOperator_{SCALAR}.pxi

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ cdef class {SCALAR_label}LinearOperator:
7777

7878
def __add__(self, x):
7979
if isinstance(x, {SCALAR_label}LinearOperator):
80-
if isinstance(self, {SCALAR_label}Multiply_Linear_Operator):
80+
if isinstance(x, {SCALAR_label}nullOperator):
81+
return self
82+
elif isinstance(self, {SCALAR_label}nullOperator):
83+
return x
84+
elif isinstance(self, {SCALAR_label}Multiply_Linear_Operator):
8185
if isinstance(x, {SCALAR_label}Multiply_Linear_Operator):
8286
return {SCALAR_label}TimeStepperLinearOperator(self.A, x.A, x.factor, self.factor)
8387
else:
@@ -119,9 +123,15 @@ cdef class {SCALAR_label}LinearOperator:
119123
tsOp = self
120124
return {SCALAR_label}TimeStepperLinearOperator(tsOp.M, tsOp.S, tsOp.facS*x, tsOp.facM*x)
121125
elif isinstance(self, {SCALAR_label}LinearOperator) and isinstance(x, (float, int, {SCALAR})):
122-
return {SCALAR_label}Multiply_Linear_Operator(self, x)
126+
if x == 0:
127+
return {SCALAR_label}nullOperator(self.num_rows, self.num_columns)
128+
else:
129+
return {SCALAR_label}Multiply_Linear_Operator(self, x)
123130
elif isinstance(x, {SCALAR_label}LinearOperator) and isinstance(self, (float, int, {SCALAR})):
124-
return {SCALAR_label}Multiply_Linear_Operator(x, self)
131+
if self == 0:
132+
return {SCALAR_label}nullOperator(x.num_rows, x.num_columns)
133+
else:
134+
return {SCALAR_label}Multiply_Linear_Operator(x, self)
125135
elif isinstance(x, complex):
126136
if isinstance(self, ComplexLinearOperator):
127137
return {SCALAR_label}Multiply_Linear_Operator(self, COMPLEX(x))
@@ -142,7 +152,10 @@ cdef class {SCALAR_label}LinearOperator:
142152
tsOp = self
143153
return {SCALAR_label}TimeStepperLinearOperator(tsOp.M, tsOp.S, tsOp.facS*x, tsOp.facM*x)
144154
else:
145-
return {SCALAR_label}Multiply_Linear_Operator(self, x)
155+
if x == 0:
156+
return {SCALAR_label}nullOperator(self.num_rows, self.num_columns)
157+
else:
158+
return {SCALAR_label}Multiply_Linear_Operator(self, x)
146159
else:
147160
raise NotImplementedError('Cannot multiply with {}'.format(x))
148161

@@ -334,11 +347,17 @@ cdef class {SCALAR_label}TimeStepperLinearOperator({SCALAR_label}LinearOperator)
334347
self.S.matvec(x, y)
335348
if self.facS != 1.0:
336349
scaleScalar(y, self.facS)
337-
if self.facM == 1.0:
338-
self.M.matvec_no_overwrite(x, y)
350+
if self.facM == 1.0:
351+
self.M.matvec_no_overwrite(x, y)
352+
else:
353+
self.M.matvec(x, self.z)
354+
assign3(y, y, 1.0, self.z, self.facM)
339355
else:
340-
self.M.matvec(x, self.z)
341-
assign3(y, y, 1.0, self.z, self.facM)
356+
if self.facM == 1.0:
357+
self.M.matvec(x, y)
358+
else:
359+
self.M.matvec(x, y)
360+
scaleScalar(y, self.facM)
342361
return 0
343362

344363
cdef INDEX_t matvec_no_overwrite(self,
@@ -804,3 +823,30 @@ cdef class {SCALAR_label}Transpose_Linear_Operator({SCALAR_label}LinearOperator)
804823

805824
def getMemorySize(self):
806825
return self.A.getMemorySize()
826+
827+
828+
cdef class {SCALAR_label}nullOperator({SCALAR_label}LinearOperator):
829+
def __init__(self, INDEX_t num_rows, INDEX_t num_columns):
830+
super({SCALAR_label}nullOperator, self).__init__(num_rows, num_columns)
831+
832+
cdef INDEX_t matvec(self,
833+
{SCALAR}_t[::1] x,
834+
{SCALAR}_t[::1] y) except -1:
835+
cdef:
836+
INDEX_t i
837+
for i in range(self.num_rows):
838+
y[i] = 0.
839+
return 0
840+
841+
cdef INDEX_t matvec_no_overwrite(self,
842+
{SCALAR}_t[::1] x,
843+
{SCALAR}_t[::1] y) except -1:
844+
return 0
845+
846+
def toarray(self):
847+
return np.zeros((self.num_rows, self.num_columns), dtype={SCALAR})
848+
849+
def get_diagonal(self):
850+
return np.zeros((min(self.num_rows, self.num_columns)), dtype={SCALAR})
851+
852+
diagonal = property(fget=get_diagonal)

base/PyNucleus_base/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from . utilsFem import driver, problem
1212
from . myTypes import REAL, INDEX, COMPLEX
1313
from . blas import uninitialized, uninitialized_like
14+
from . timestepping import timestepperFactory
1415

1516

1617
def get_include():
@@ -52,5 +53,5 @@ def get_include():
5253

5354

5455
__all__ = ['REAL', 'INDEX', 'COMPLEX',
55-
'solverFactory',
56+
'solverFactory', 'timestepperFactory',
5657
'driver', 'problem']

base/PyNucleus_base/linear_operators.pxd

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,6 @@ cdef class blockOperator(LinearOperator):
9696
REAL_t[::1] y) except -1
9797

9898

99-
cdef class nullOperator(LinearOperator):
100-
cdef INDEX_t matvec(self,
101-
REAL_t[::1] x,
102-
REAL_t[::1] y) except -1
103-
cdef INDEX_t matvec_no_overwrite(self,
104-
REAL_t[::1] x,
105-
REAL_t[::1] y) except -1
106-
107-
10899
cdef class identityOperator(LinearOperator):
109100
cdef:
110101
REAL_t alpha

base/PyNucleus_base/linear_operators.pyx

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -908,33 +908,6 @@ cdef class blockDiagonalOperator(blockOperator):
908908
super(blockDiagonalOperator, self).__init__(subblocks)
909909

910910

911-
cdef class nullOperator(LinearOperator):
912-
def __init__(self, INDEX_t num_rows, INDEX_t num_columns):
913-
super(nullOperator, self).__init__(num_rows, num_columns)
914-
915-
cdef INDEX_t matvec(self,
916-
REAL_t[::1] x,
917-
REAL_t[::1] y) except -1:
918-
cdef:
919-
INDEX_t i
920-
for i in range(self.num_rows):
921-
y[i] = 0.
922-
return 0
923-
924-
cdef INDEX_t matvec_no_overwrite(self,
925-
REAL_t[::1] x,
926-
REAL_t[::1] y) except -1:
927-
return 0
928-
929-
def toarray(self):
930-
return np.zeros((self.num_rows, self.num_columns), dtype=REAL)
931-
932-
def get_diagonal(self):
933-
return np.zeros((min(self.num_rows, self.num_columns)), dtype=REAL)
934-
935-
diagonal = property(fget=get_diagonal)
936-
937-
938911
cdef class identityOperator(LinearOperator):
939912
def __init__(self, INDEX_t num_rows, REAL_t alpha=1.0):
940913
super(identityOperator, self).__init__(num_rows, num_rows)

0 commit comments

Comments
 (0)