Skip to content

Commit e551a5d

Browse files
committed
complete cg with dirichlet flag, add example, fix di filter
1 parent 16e5cd7 commit e551a5d

6 files changed

+130
-51
lines changed

example/linalg/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ ADD_EXAMPLE(get_norm)
4141
ADD_EXAMPLE(solve1)
4242
ADD_EXAMPLE(solve2)
4343
ADD_EXAMPLE(solve3)
44+
ADD_EXAMPLE(solve_cg)
4445
ADD_EXAMPLE(solve_pccg)
4546
ADD_EXAMPLE(sparse_from_ijv)
4647
ADD_EXAMPLE(sparse_data_accessors)

example/linalg/example_solve_cg.f90

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
program example_solve_pccg
2+
use stdlib_kinds, only: dp
3+
use stdlib_linalg_iterative_solvers, only: solve_cg
4+
5+
real(dp) :: matrix(2,2)
6+
real(dp) :: x(2), load(2)
7+
8+
matrix = reshape( [4, 1,&
9+
1, 3] , [2,2])
10+
11+
x = dble( [2,1] ) !> initial guess
12+
load = dble( [1,2] ) !> load vector
13+
14+
call solve_cg(matrix, load, x, restart=.false.)
15+
print *, x !> solution: [0.0909, 0.6364]
16+
17+
end program

example/linalg/example_solve_pccg.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ program example_solve_pccg
2424
dirichlet([1,5]) = .true._1
2525

2626
call solve_pccg(laplacian, load, x, tol=1.d-6, di=dirichlet)
27-
print *, x
27+
print *, x !> solution: [0.0, 2.5, 5.0, 2.5, 0.0]
2828
x = 0._dp
2929

3030
call solve_pccg(laplacian_csr, load, x, tol=1.d-6, di=dirichlet)
31-
print *, x
31+
print *, x !> solution: [0.0, 2.5, 5.0, 2.5, 0.0]
3232
end program

src/stdlib_linalg_iterative_solvers.fypp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@ module stdlib_linalg_iterative_solvers
4242

4343
interface solve_cg_generic
4444
#:for k, t, s in R_KINDS_TYPES
45-
module subroutine solve_cg_generic_${s}$(A,b,x,tol,maxiter,workspace)
45+
module subroutine solve_cg_generic_${s}$(A,b,x,di,tol,maxiter,restart,workspace)
4646
class(linop_${s}$), intent(in) :: A
4747
${t}$, intent(in) :: b(:)
4848
${t}$, intent(inout) :: x(:)
4949
${t}$, intent(in) :: tol
50+
logical(1), intent(in) :: di(:)
5051
integer, intent(in) :: maxiter
52+
logical, intent(in) :: restart
5153
type(solver_workspace_${s}$), intent(inout) :: workspace
5254
end subroutine
5355
#:endfor
@@ -57,7 +59,7 @@ module stdlib_linalg_iterative_solvers
5759
interface solve_cg
5860
#:for matrix in MATRIX_TYPES
5961
#:for k, t, s in R_KINDS_TYPES
60-
module subroutine solve_cg_${matrix}$_${s}$(A,b,x,tol,maxiter,workspace)
62+
module subroutine solve_cg_${matrix}$_${s}$(A,b,x,di,tol,maxiter,restart,workspace)
6163
#:if matrix == "dense"
6264
${t}$, intent(in) :: A(:,:)
6365
#:else
@@ -66,7 +68,9 @@ module stdlib_linalg_iterative_solvers
6668
${t}$, intent(in) :: b(:)
6769
${t}$, intent(inout) :: x(:)
6870
${t}$, intent(in), optional :: tol
71+
logical(1), intent(in), optional, target :: di(:)
6972
integer, intent(in), optional :: maxiter
73+
logical, intent(in), optional :: restart
7074
type(solver_workspace_${s}$), optional, intent(inout), target :: workspace
7175
end subroutine
7276
#:endfor

src/stdlib_linalg_iterative_solvers_cg.fypp

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,43 +15,63 @@ submodule(stdlib_linalg_iterative_solvers) stdlib_linalg_iterative_cg
1515
contains
1616

1717
#:for k, t, s in R_KINDS_TYPES
18-
module subroutine solve_cg_generic_${s}$(A,b,x,tol,maxiter,workspace)
18+
module subroutine solve_cg_generic_${s}$(A,b,x,di,tol,maxiter,restart,workspace)
1919
class(linop_${s}$), intent(in) :: A
2020
${t}$, intent(in) :: b(:), tol
2121
${t}$, intent(inout) :: x(:)
22+
logical(1), intent(in) :: di(:)
2223
integer, intent(in) :: maxiter
24+
logical, intent(in) :: restart
2325
type(solver_workspace_${s}$), intent(inout) :: workspace
2426
!-------------------------
2527
integer :: iter
26-
${t}$ :: rtr, rtrold, alpha, beta, norm0_sq
28+
${t}$ :: norm_sq, norm_sq_old, norm_sq0, residual0, residual
29+
${t}$ :: alpha, beta, tolsq
2730
!-------------------------
28-
associate( p => workspace%tmp(:,1), &
29-
r => workspace%tmp(:,2), &
31+
associate( P => workspace%tmp(:,1), &
32+
R => workspace%tmp(:,2), &
3033
Ap => workspace%tmp(:,3))
31-
x = zero_${s}$
32-
rtr = A%inner_product(r, r)
33-
norm0_sq = A%inner_product(b, b)
34-
p = b
34+
35+
norm_sq0 = A%inner_product(B, B)
36+
residual0 = sqrt(norm_sq0)
37+
38+
if(restart) X = zero_${s}$
39+
X = merge( B, X, di ) !> copy dirichlet load conditions encoded in B and indicated by di
40+
41+
call A%matvec(X, R)
42+
R = merge( zero_${s}$, B - R , di ) !> R = B - A*X
43+
norm_sq = A%inner_product(R, R)
44+
45+
P = R
46+
47+
tolsq = tol*tol
3548
beta = zero_${s}$
36-
iter = 1
37-
do while( rtr > tol**2 * norm0_sq .and. iter < maxiter)
38-
p = r + beta * p
39-
call A%matvec(p,Ap)
40-
alpha = rtr / A%inner_product(p, Ap)
41-
x = x + alpha * p
42-
r = r - alpha * Ap
43-
rtrold = rtr
44-
rtr = A%inner_product(r, r)
45-
beta = rtr / rtrold
49+
iter = 0
50+
do while( norm_sq > tolsq * norm_sq0 .and. iter < maxiter)
51+
call A%matvec(P,Ap)
52+
Ap = merge( zero_${s}$, Ap, di )
53+
54+
alpha = norm_sq / A%inner_product(P, Ap)
55+
56+
X = X + alpha * P
57+
R = R - alpha * Ap
58+
59+
norm_sq_old = norm_sq
60+
norm_sq = A%inner_product(R, R)
61+
beta = norm_sq / norm_sq_old
62+
63+
P = R + beta * P
64+
4665
iter = iter + 1
4766
end do
67+
residual = sqrt(norm_sq)
4868
end associate
4969
end subroutine
5070
#:endfor
5171

5272
#:for matrix in MATRIX_TYPES
5373
#:for k, t, s in R_KINDS_TYPES
54-
module subroutine solve_cg_${matrix}$_${s}$(A,b,x,tol,maxiter,workspace)
74+
module subroutine solve_cg_${matrix}$_${s}$(A,b,x,di,tol,maxiter,restart,workspace)
5575
#:if matrix == "dense"
5676
${t}$, intent(in) :: A(:,:)
5777
#:else
@@ -60,30 +80,57 @@ contains
6080
${t}$, intent(in) :: b(:)
6181
${t}$, intent(inout) :: x(:)
6282
${t}$, intent(in), optional :: tol
83+
logical(1), intent(in), optional, target :: di(:)
6384
integer, intent(in), optional :: maxiter
85+
logical, intent(in), optional :: restart
6486
type(solver_workspace_${s}$), optional, intent(inout), target :: workspace
6587
!-------------------------
6688
type(linop_${s}$) :: op
6789
type(solver_workspace_${s}$), pointer :: workspace_
6890
integer :: n, maxiter_
6991
${t}$ :: tol_
92+
logical :: restart_
93+
logical(1), pointer :: di_(:)
7094
!-------------------------
7195
n = size(b)
96+
maxiter_ = n; if(present(maxiter)) maxiter_ = maxiter
97+
restart_ = .true.; if(present(restart)) restart_ = restart
98+
tol_ = 1.e-4_${s}$; if(present(tol)) tol_ = tol
99+
100+
!-------------------------
101+
! internal memory setup
72102
op%matvec => default_matvec
73103
op%inner_product => default_dot
74-
75-
maxiter_ = n
76-
if(present(maxiter)) maxiter_ = maxiter
77-
tol_ = 1.e-4_${s}$
78-
if(present(tol)) tol_ = tol
104+
if(present(di))then
105+
di_ => di
106+
else
107+
allocate(di_(n),source=.false._1)
108+
end if
79109

80110
if(present(workspace)) then
81111
if(.not.allocated(workspace_%tmp)) allocate( workspace_%tmp(n,3) )
82112
workspace_ => workspace
83113
else
84-
allocate( workspace_%tmp(n,3) )
114+
allocate( workspace_ )
115+
allocate( workspace_%tmp(n,3), source = zero_${s}$ )
116+
end if
117+
!-------------------------
118+
! main call to the solver
119+
call solve_cg_generic(op,b,x,di_,tol_,maxiter_,restart_,workspace_)
120+
121+
!-------------------------
122+
! internal memory cleanup
123+
if(present(di))then
124+
di_ => null()
125+
else
126+
deallocate(di_)
127+
end if
128+
if(present(workspace)) then
129+
workspace_ => null()
130+
else
131+
deallocate( workspace_%tmp )
132+
deallocate( workspace_ )
85133
end if
86-
call solve_cg_generic(op,b,x,tol,maxiter_,workspace_)
87134
contains
88135

89136
subroutine default_matvec(x,y)

src/stdlib_linalg_iterative_solvers_pccg.fypp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ contains
2525
type(solver_workspace_${s}$), intent(inout) :: workspace
2626
!-------------------------
2727
integer :: iter
28-
${t}$ :: norm_sq, norm_sq0, norm_sq_old, residual0
28+
${t}$ :: norm_sq, norm_sq0, norm_sq_old, residual0, residual
2929
${t}$ :: zr1, zr2, zv2, alpha, beta, tolsq
3030
!-------------------------
3131
associate( R => workspace%tmp(:,1), &
@@ -37,23 +37,22 @@ contains
3737
residual0 = sqrt(norm_sq0)
3838
if ( norm_sq0 > zero_${s}$ ) then
3939
if(restart) X = zero_${s}$
40-
where( di ) X = B
40+
X = merge( B, X, di ) !> copy dirichlet load conditions encoded in B and indicated by di
4141

4242
call A%matvec(X, R)
43-
R = B - R
44-
where( di ) R = zero_${s}$
43+
R = merge( zero_${s}$, B - R , di ) !> R = B - A*X
4544

46-
call A%preconditionner(R,P)
47-
where( di ) P = zero_${s}$
45+
call A%preconditionner(R,P) !> P = M^{-1}*R
46+
P = merge( zero_${s}$, P, di )
4847

4948
tolsq = tol*tol
5049
iter = 0
5150
zr1 = zero_${s}$
5251
zr2 = one_${s}$
5352
do while ( (iter < maxiter) .AND. (norm_sq > tolsq * norm_sq0) )
5453

55-
call A%preconditionner(R,S)
56-
where ( di ) S = zero_${s}$
54+
call A%preconditionner(R,S) !> S = M^{-1}*R
55+
S = merge( zero_${s}$, S, di )
5756
zr2 = A%inner_product( R, S )
5857

5958
if (iter>0) then
@@ -62,7 +61,7 @@ contains
6261
end if
6362

6463
call A%matvec(P, Q)
65-
where( di ) Q = zero_${s}$
64+
Q = merge( zero_${s}$, Q, di )
6665
zv2 = A%inner_product( P, Q )
6766

6867
alpha = zr2 / zv2
@@ -75,6 +74,7 @@ contains
7574
iter = iter + 1
7675
end do
7776
end if
77+
residual = sqrt(norm_sq)
7878
end associate
7979
end subroutine
8080
#:endfor
@@ -103,18 +103,15 @@ contains
103103
logical(1), pointer :: di_(:)
104104
!-------------------------
105105
n = size(b)
106+
maxiter_ = n; if(present(maxiter)) maxiter_ = maxiter
107+
restart_ = .true.; if(present(restart)) restart_ = restart
108+
tol_ = 1.e-4_${s}$; if(present(tol)) tol_ = tol
106109

110+
!-------------------------
111+
! internal memory setup
107112
op%matvec => default_matvec
108113
op%inner_product => default_dot
109114
op%preconditionner => default_preconditionner
110-
111-
maxiter_ = n
112-
if(present(maxiter)) maxiter_ = maxiter
113-
restart_ = .true.
114-
if(present(restart)) restart_ = restart
115-
tol_ = 1.e-4_${s}$
116-
if(present(tol)) tol_ = tol
117-
118115
if(present(di))then
119116
di_ => di
120117
else
@@ -126,12 +123,25 @@ contains
126123
workspace_ => workspace
127124
else
128125
allocate( workspace_ )
129-
allocate( workspace_%tmp(n,4) )
126+
allocate( workspace_%tmp(n,4) , source = zero_${s}$ )
130127
end if
131-
132-
call solve_pccg_generic(op,b,x,di_,tol,maxiter_,restart_,workspace_)
128+
!-------------------------
129+
! main call to the solver
130+
call solve_pccg_generic(op,b,x,di_,tol_,maxiter_,restart_,workspace_)
133131

134-
if(.not.present(di)) deallocate(di_)
132+
!-------------------------
133+
! internal memory cleanup
134+
if(present(di))then
135+
di_ => null()
136+
else
137+
deallocate(di_)
138+
end if
139+
if(present(workspace)) then
140+
workspace_ => null()
141+
else
142+
deallocate( workspace_%tmp )
143+
deallocate( workspace_ )
144+
end if
135145
contains
136146

137147
subroutine default_matvec(x,y)

0 commit comments

Comments
 (0)