Skip to content

Commit 379cd81

Browse files
committed
refactor to remove hard dependency on dirichlet BCs
1 parent 07b97ce commit 379cd81

File tree

4 files changed

+50
-46
lines changed

4 files changed

+50
-46
lines changed

example/linalg/example_solve_custom.f90

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ subroutine solve_pccg_custom(A,b,x,di,tol,maxiter,restart,workspace)
5656
where(abs(diagonal)>epsilon(0.d0)) diagonal = 1._dp/diagonal
5757
!-------------------------
5858
! main call to the solver
59-
call solve_pccg_generic(op,M,b,x,di_,tol_,maxiter_,restart_,workspace_)
59+
call solve_pccg_generic(op,M,b,x,tol_,maxiter_,workspace_)
6060

6161
!-------------------------
6262
! internal memory cleanup
@@ -70,15 +70,20 @@ subroutine solve_pccg_custom(A,b,x,di,tol,maxiter,restart,workspace)
7070
workspace_ => null()
7171
contains
7272

73-
subroutine my_apply(x,y)
73+
subroutine my_apply(x,y,alpha,beta)
7474
real(dp), intent(in) :: x(:)
7575
real(dp), intent(inout) :: y(:)
76-
call spmv( A , x, y )
76+
real(dp), intent(in) :: alpha
77+
real(dp), intent(in) :: beta
78+
call spmv( A , x, y , alpha, beta )
79+
y = merge( 0._dp, y, di_ )
7780
end subroutine
78-
subroutine my_jacobi_preconditionner(x,y)
81+
subroutine my_jacobi_preconditionner(x,y,alpha,beta)
7982
real(dp), intent(in) :: x(:)
8083
real(dp), intent(inout) :: y(:)
81-
y = diagonal * x
84+
real(dp), intent(in) :: alpha
85+
real(dp), intent(in) :: beta
86+
y = merge( 0._dp, diagonal * x , di_ )
8287
end subroutine
8388
pure real(dp) function my_dot(x,y) result(r)
8489
real(dp), intent(in) :: x(:)

src/stdlib_linalg_iterative_solvers.fypp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ module stdlib_linalg_iterative_solvers
3636

3737
abstract interface
3838
#:for k, t, s in R_KINDS_TYPES
39-
subroutine vector_sub_${s}$(x,y)
39+
subroutine vector_sub_${s}$(x,y,alpha,beta)
4040
import :: ${k}$
4141
${t}$, intent(in) :: x(:)
4242
${t}$, intent(inout) :: y(:)
43+
${t}$, intent(in) :: alpha
44+
${t}$, intent(in) :: beta
4345
end subroutine
4446
pure ${t}$ function reduction_sub_${s}$(x,y) result(r)
4547
import :: ${k}$
@@ -57,14 +59,12 @@ module stdlib_linalg_iterative_solvers
5759

5860
interface solve_cg_generic
5961
#:for k, t, s in R_KINDS_TYPES
60-
module subroutine solve_cg_generic_${s}$(A,b,x,di,tol,maxiter,restart,workspace)
62+
module subroutine solve_cg_generic_${s}$(A,b,x,tol,maxiter,workspace)
6163
class(linop_${s}$), intent(in) :: A
6264
${t}$, intent(in) :: b(:)
6365
${t}$, intent(inout) :: x(:)
6466
${t}$, intent(in) :: tol
65-
logical(1), intent(in) :: di(:)
6667
integer, intent(in) :: maxiter
67-
logical, intent(in) :: restart
6868
type(solver_workspace_${s}$), intent(inout) :: workspace
6969
end subroutine
7070
#:endfor
@@ -95,15 +95,13 @@ module stdlib_linalg_iterative_solvers
9595

9696
interface solve_pccg_generic
9797
#:for k, t, s in R_KINDS_TYPES
98-
module subroutine solve_pccg_generic_${s}$(A,M,b,x,di,tol,maxiter,restart,workspace)
98+
module subroutine solve_pccg_generic_${s}$(A,M,b,x,tol,maxiter,workspace)
9999
class(linop_${s}$), intent(in) :: A
100100
class(linop_${s}$), intent(in) :: M !> preconditionner
101101
${t}$, intent(in) :: b(:)
102102
${t}$, intent(inout) :: x(:)
103103
${t}$, intent(in) :: tol
104-
logical(1), intent(in) :: di(:)
105104
integer, intent(in) :: maxiter
106-
logical, intent(in) :: restart
107105
type(solver_workspace_${s}$), intent(inout) :: workspace
108106
end subroutine
109107
#:endfor
@@ -124,7 +122,7 @@ module stdlib_linalg_iterative_solvers
124122
${t}$, intent(in), optional :: tol
125123
logical(1), intent(in), optional, target :: di(:)
126124
integer, intent(in), optional :: maxiter
127-
logical, intent(in), optional :: restart
125+
logical, intent(in), optional :: restart
128126
class(linop_${s}$), optional , intent(in), target :: M !> preconditionner
129127
type(solver_workspace_${s}$), optional, intent(inout), target :: workspace
130128
end subroutine

src/stdlib_linalg_iterative_solvers_cg.fypp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@ submodule(stdlib_linalg_iterative_solvers) stdlib_linalg_iterative_cg
1414
contains
1515

1616
#:for k, t, s in R_KINDS_TYPES
17-
module subroutine solve_cg_generic_${s}$(A,b,x,di,tol,maxiter,restart,workspace)
17+
module subroutine solve_cg_generic_${s}$(A,b,x,tol,maxiter,workspace)
1818
class(linop_${s}$), intent(in) :: A
1919
${t}$, intent(in) :: b(:), tol
2020
${t}$, intent(inout) :: x(:)
21-
logical(1), intent(in) :: di(:)
2221
integer, intent(in) :: maxiter
23-
logical, intent(in) :: restart
2422
type(solver_workspace_${s}$), intent(inout) :: workspace
2523
!-------------------------
2624
integer :: iter
@@ -35,11 +33,8 @@ contains
3533
norm_sq0 = A%inner_product(B, B)
3634
if(associated(workspace%callback)) call workspace%callback(x, norm_sq0, iter)
3735

38-
if(restart) X = zero_${s}$
39-
X = merge( B, X, di ) !> copy dirichlet load conditions encoded in B and indicated by di
40-
41-
call A%apply(X, R)
42-
R = merge( zero_${s}$, B - R , di ) !> R = B - A*X
36+
R = B
37+
call A%apply(X, R, alpha= -one_${s}$, beta=one_${s}$) !> R = B - A*X
4338
norm_sq = A%inner_product(R, R)
4439

4540
P = R
@@ -48,8 +43,7 @@ contains
4843
beta = zero_${s}$
4944
if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
5045
do while( norm_sq > tolsq * norm_sq0 .and. iter < maxiter)
51-
call A%apply(P,Ap)
52-
Ap = merge( zero_${s}$, Ap, di )
46+
call A%apply(P,Ap, alpha= one_${s}$, beta=zero_${s}$) !> Ap = A*P
5347

5448
alpha = norm_sq / A%inner_product(P, Ap)
5549

@@ -116,7 +110,9 @@ contains
116110
if(.not.allocated(workspace_%tmp)) allocate( workspace_%tmp(n,size_wksp_cg), source = zero_${s}$ )
117111
!-------------------------
118112
! main call to the solver
119-
call solve_cg_generic(op,b,x,di_,tol_,maxiter_,restart_,workspace_)
113+
if(restart_) x = zero_${s}$
114+
x = merge( b, x, di_ ) !> copy dirichlet load conditions encoded in B and indicated by di
115+
call solve_cg_generic(op,b,x,tol_,maxiter_,workspace_)
120116

121117
!-------------------------
122118
! internal memory cleanup
@@ -130,14 +126,17 @@ contains
130126
workspace_ => null()
131127
contains
132128

133-
subroutine default_matvec(x,y)
129+
subroutine default_matvec(x,y,alpha,beta)
134130
${t}$, intent(in) :: x(:)
135131
${t}$, intent(inout) :: y(:)
132+
${t}$, intent(in) :: alpha
133+
${t}$, intent(in) :: beta
136134
#:if matrix == "dense"
137-
y = matmul(A,x)
135+
y = alpha * matmul(A,x) + beta * y
138136
#:else
139-
call spmv( A , x, y )
137+
call spmv( A , x, y , alpha, beta )
140138
#:endif
139+
y = merge( zero_${s}$, y, di_ )
141140
end subroutine
142141
end subroutine
143142

src/stdlib_linalg_iterative_solvers_pccg.fypp

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@ submodule(stdlib_linalg_iterative_solvers) stdlib_linalg_iterative_pccg
1414
contains
1515

1616
#:for k, t, s in R_KINDS_TYPES
17-
module subroutine solve_pccg_generic_${s}$(A,M,b,x,di,tol,maxiter,restart,workspace)
17+
module subroutine solve_pccg_generic_${s}$(A,M,b,x,tol,maxiter,workspace)
1818
class(linop_${s}$), intent(in) :: A
1919
class(linop_${s}$), intent(in) :: M !> preconditionner
2020
${t}$, intent(in) :: b(:), tol
2121
${t}$, intent(inout) :: x(:)
22-
logical(1), intent(in) :: di(:)
2322
integer, intent(in) :: maxiter
24-
logical, intent(in) :: restart
2523
type(solver_workspace_${s}$), intent(inout) :: workspace
2624
!-------------------------
2725
integer :: iter
@@ -38,40 +36,37 @@ contains
3836
if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
3937

4038
if ( norm_sq0 > zero_${s}$ ) then
41-
if(restart) X = zero_${s}$
42-
X = merge( B, X, di ) !> copy dirichlet load conditions encoded in B and indicated by di
4339

44-
call A%apply(X, R)
45-
R = merge( zero_${s}$, B - R , di ) !> R = B - A*X
40+
R = B
41+
call A%apply(X, R, alpha= -one_${s}$, beta=one_${s}$) !> R = B - A*X
4642

47-
call M%apply(R,P) !> P = M^{-1}*R
48-
P = merge( zero_${s}$, P, di )
43+
call M%apply(R,P, alpha= one_${s}$, beta=zero_${s}$) !> P = M^{-1}*R
4944

5045
tolsq = tol*tol
5146

5247
zr1 = zero_${s}$
5348
zr2 = one_${s}$
5449
do while ( (iter < maxiter) .AND. (norm_sq > tolsq * norm_sq0) )
5550

56-
call M%apply(R,S) !> S = M^{-1}*R
57-
S = merge( zero_${s}$, S, di )
51+
call M%apply(R,S, alpha= one_${s}$, beta=zero_${s}$) !> S = M^{-1}*R
5852
zr2 = A%inner_product( R, S )
5953

6054
if (iter>0) then
6155
beta = zr2 / zr1
6256
P = S + beta * P
6357
end if
6458

65-
call A%apply(P, Q)
66-
Q = merge( zero_${s}$, Q, di )
59+
call A%apply(P, Q, alpha= one_${s}$, beta=zero_${s}$) !> Q = A*P
6760
zv2 = A%inner_product( P, Q )
6861

6962
alpha = zr2 / zv2
7063

7164
X = X + alpha * P
7265
R = R - alpha * Q
66+
7367
norm_sq = A%inner_product( R, R )
7468
norm_sq_old = norm_sq
69+
7570
zr1 = zr2
7671
iter = iter + 1
7772
if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
@@ -134,7 +129,9 @@ contains
134129
if(.not.allocated(workspace_%tmp)) allocate( workspace_%tmp(n,size_wksp_pccg) , source = zero_${s}$ )
135130
!-------------------------
136131
! main call to the solver
137-
call solve_pccg_generic(op,M_,b,x,di_,tol_,maxiter_,restart_,workspace_)
132+
if(restart_) x = zero_${s}$
133+
x = merge( b, x, di_ ) !> copy dirichlet load conditions encoded in B and indicated by di
134+
call solve_pccg_generic(op,M_,b,x,tol_,maxiter_,workspace_)
138135

139136
!-------------------------
140137
! internal memory cleanup
@@ -149,19 +146,24 @@ contains
149146
workspace_ => null()
150147
contains
151148

152-
subroutine default_matvec(x,y)
149+
subroutine default_matvec(x,y,alpha,beta)
153150
${t}$, intent(in) :: x(:)
154151
${t}$, intent(inout) :: y(:)
152+
${t}$, intent(in) :: alpha
153+
${t}$, intent(in) :: beta
155154
#:if matrix == "dense"
156-
y = matmul(A,x)
155+
y = alpha * matmul(A,x) + beta * y
157156
#:else
158-
call spmv( A , x, y )
157+
call spmv( A , x, y , alpha, beta )
159158
#:endif
159+
y = merge( zero_${s}$, y, di_ )
160160
end subroutine
161-
subroutine default_preconditionner(x,y)
161+
subroutine default_preconditionner(x,y,alpha,beta)
162162
${t}$, intent(in) :: x(:)
163163
${t}$, intent(inout) :: y(:)
164-
y = x
164+
${t}$, intent(in) :: alpha
165+
${t}$, intent(in) :: beta
166+
y = merge( zero_${s}$, x, di_ )
165167
end subroutine
166168
end subroutine
167169

0 commit comments

Comments
 (0)