@@ -77,7 +77,11 @@ cdef class {SCALAR_label}LinearOperator:
7777
7878 def __add__ (self , x ):
7979 if isinstance (x, {SCALAR_label}LinearOperator):
80- if isinstance (self , {SCALAR_label}Multiply_Linear_Operator):
80+ if isinstance (x, {SCALAR_label}nullOperator):
81+ return self
82+ elif isinstance (self , {SCALAR_label}nullOperator):
83+ return x
84+ elif isinstance (self , {SCALAR_label}Multiply_Linear_Operator):
8185 if isinstance (x, {SCALAR_label}Multiply_Linear_Operator):
8286 return {SCALAR_label}TimeStepperLinearOperator(self .A, x.A, x.factor, self .factor)
8387 else :
@@ -119,9 +123,15 @@ cdef class {SCALAR_label}LinearOperator:
119123 tsOp = self
120124 return {SCALAR_label}TimeStepperLinearOperator(tsOp.M, tsOp.S, tsOp.facS* x, tsOp.facM* x)
121125 elif isinstance (self , {SCALAR_label}LinearOperator) and isinstance (x, (float , int , {SCALAR})):
122- return {SCALAR_label}Multiply_Linear_Operator(self , x)
126+ if x == 0 :
127+ return {SCALAR_label}nullOperator(self .num_rows, self .num_columns)
128+ else :
129+ return {SCALAR_label}Multiply_Linear_Operator(self , x)
123130 elif isinstance (x, {SCALAR_label}LinearOperator) and isinstance (self , (float , int , {SCALAR})):
124- return {SCALAR_label}Multiply_Linear_Operator(x, self )
131+ if self == 0 :
132+ return {SCALAR_label}nullOperator(x.num_rows, x.num_columns)
133+ else :
134+ return {SCALAR_label}Multiply_Linear_Operator(x, self )
125135 elif isinstance (x, complex ):
126136 if isinstance (self , ComplexLinearOperator):
127137 return {SCALAR_label}Multiply_Linear_Operator(self , COMPLEX(x))
@@ -142,7 +152,10 @@ cdef class {SCALAR_label}LinearOperator:
142152 tsOp = self
143153 return {SCALAR_label}TimeStepperLinearOperator(tsOp.M, tsOp.S, tsOp.facS* x, tsOp.facM* x)
144154 else :
145- return {SCALAR_label}Multiply_Linear_Operator(self , x)
155+ if x == 0 :
156+ return {SCALAR_label}nullOperator(self .num_rows, self .num_columns)
157+ else :
158+ return {SCALAR_label}Multiply_Linear_Operator(self , x)
146159 else :
147160 raise NotImplementedError (' Cannot multiply with {}' .format(x))
148161
@@ -334,11 +347,17 @@ cdef class {SCALAR_label}TimeStepperLinearOperator({SCALAR_label}LinearOperator)
334347 self .S.matvec(x, y)
335348 if self .facS != 1.0 :
336349 scaleScalar(y, self .facS)
337- if self .facM == 1.0 :
338- self .M.matvec_no_overwrite(x, y)
350+ if self .facM == 1.0 :
351+ self .M.matvec_no_overwrite(x, y)
352+ else :
353+ self .M.matvec(x, self .z)
354+ assign3(y, y, 1.0 , self .z, self .facM)
339355 else :
340- self .M.matvec(x, self .z)
341- assign3(y, y, 1.0 , self .z, self .facM)
356+ if self .facM == 1.0 :
357+ self .M.matvec(x, y)
358+ else :
359+ self .M.matvec(x, y)
360+ scaleScalar(y, self .facM)
342361 return 0
343362
344363 cdef INDEX_t matvec_no_overwrite(self ,
@@ -804,3 +823,30 @@ cdef class {SCALAR_label}Transpose_Linear_Operator({SCALAR_label}LinearOperator)
804823
805824 def getMemorySize (self ):
806825 return self .A.getMemorySize()
826+
827+
828+ cdef class {SCALAR_label}nullOperator({SCALAR_label}LinearOperator):
829+ def __init__ (self , INDEX_t num_rows , INDEX_t num_columns ):
830+ super ({SCALAR_label}nullOperator, self ).__init__(num_rows, num_columns)
831+
832+ cdef INDEX_t matvec(self ,
833+ {SCALAR}_t[::1 ] x,
834+ {SCALAR}_t[::1 ] y) except - 1 :
835+ cdef:
836+ INDEX_t i
837+ for i in range (self .num_rows):
838+ y[i] = 0.
839+ return 0
840+
841+ cdef INDEX_t matvec_no_overwrite(self ,
842+ {SCALAR}_t[::1 ] x,
843+ {SCALAR}_t[::1 ] y) except - 1 :
844+ return 0
845+
846+ def toarray (self ):
847+ return np.zeros((self .num_rows, self .num_columns), dtype = {SCALAR})
848+
849+ def get_diagonal (self ):
850+ return np.zeros((min (self .num_rows, self .num_columns)), dtype = {SCALAR})
851+
852+ diagonal = property(fget = get_diagonal)
0 commit comments