Skip to content

Commit 0b01dbd

Browse files
committed
Add a customizable logger facility, change linop matvec to apply
1 parent bfafaa5 commit 0b01dbd

File tree

4 files changed

+44
-21
lines changed

4 files changed

+44
-21
lines changed

example/linalg/example_solve_custom.f90

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,18 @@ subroutine solve_pccg_custom(A,b,x,di,tol,maxiter,restart,workspace)
2525
logical :: restart_
2626
logical(1), pointer :: di_(:)
2727
real(dp), allocatable :: diagonal(:)
28+
real(dp) :: norm_sq0
2829
!-------------------------
2930
n = size(b)
3031
maxiter_ = n; if(present(maxiter)) maxiter_ = maxiter
3132
restart_ = .true.; if(present(restart)) restart_ = restart
3233
tol_ = 1.e-4_dp; if(present(tol)) tol_ = tol
33-
34+
norm_sq0 = 0.d0
3435
!-------------------------
3536
! internal memory setup
36-
op%matvec => my_matvec
37+
op%apply => my_matvec
3738
op%inner_product => my_dot
38-
M%matvec => jacobi_preconditionner
39+
M%apply => jacobi_preconditionner
3940
if(present(di))then
4041
di_ => di
4142
else
@@ -48,6 +49,7 @@ subroutine solve_pccg_custom(A,b,x,di,tol,maxiter,restart,workspace)
4849
allocate( workspace_ )
4950
end if
5051
if(.not.allocated(workspace_%tmp)) allocate( workspace_%tmp(n,size_wksp_pccg) , source = 0.d0 )
52+
workspace_%callback => my_logger
5153
!-------------------------
5254
! Jacobi preconditionner factorization
5355
call diag(A,diagonal)
@@ -83,6 +85,13 @@ pure real(dp) function my_dot(x,y) result(r)
8385
real(dp), intent(in) :: y(:)
8486
r = dot_product(x,y)
8587
end function
88+
subroutine my_logger(x,norm_sq,iter)
89+
real(dp), intent(in) :: x(:)
90+
real(dp), intent(in) :: norm_sq
91+
integer, intent(in) :: iter
92+
if(iter == 0) norm_sq0 = norm_sq
93+
print *, "Iteration: ", iter, " Residual: ", sqrt(norm_sq), " Relative: ", sqrt(norm_sq)/sqrt(norm_sq0)
94+
end subroutine
8695
end subroutine
8796

8897
end module custom_solver

src/stdlib_linalg_iterative_solvers.fypp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,26 @@ module stdlib_linalg_iterative_solvers
1010
implicit none
1111
private
1212

13+
!> brief workspace size for the iterative solvers
14+
!> details The size of the workspace is defined by the number of vectors used in the iterative solver.
1315
enum, bind(c)
1416
enumerator :: size_wksp_cg = 3
1517
enumerator :: size_wksp_pccg = 4
1618
end enum
1719
public :: size_wksp_cg, size_wksp_pccg
1820

21+
!> linear operator class for the iterative solvers
1922
#:for k, t, s in R_KINDS_TYPES
2023
type, public :: linop_${s}$
21-
procedure(vector_sub_${s}$), nopass, pointer :: matvec => null()
24+
procedure(vector_sub_${s}$), nopass, pointer :: apply => null()
2225
procedure(reduction_sub_${s}$), nopass, pointer :: inner_product => default_dot_${s}$
2326
end type
2427
#:endfor
2528

2629
#:for k, t, s in R_KINDS_TYPES
2730
type, public :: solver_workspace_${s}$
2831
${t}$, allocatable :: tmp(:,:)
32+
procedure(logger_sub_${s}$), pointer, nopass :: callback => null()
2933
end type
3034

3135
#:endfor
@@ -42,6 +46,12 @@ module stdlib_linalg_iterative_solvers
4246
${t}$, intent(in) :: x(:)
4347
${t}$, intent(in) :: y(:)
4448
end function
49+
subroutine logger_sub_${s}$(x,norm_sq,iter)
50+
import :: ${k}$
51+
${t}$, intent(in) :: x(:)
52+
${t}$, intent(in) :: norm_sq
53+
integer, intent(in) :: iter
54+
end subroutine
4555
#:endfor
4656
end interface
4757

src/stdlib_linalg_iterative_solvers_cg.fypp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,31 @@ contains
2424
type(solver_workspace_${s}$), intent(inout) :: workspace
2525
!-------------------------
2626
integer :: iter
27-
${t}$ :: norm_sq, norm_sq_old, norm_sq0, residual0, residual
27+
${t}$ :: norm_sq, norm_sq_old, norm_sq0
2828
${t}$ :: alpha, beta, tolsq
2929
!-------------------------
30+
iter = 0
3031
associate( P => workspace%tmp(:,1), &
3132
R => workspace%tmp(:,2), &
3233
Ap => workspace%tmp(:,3))
3334

3435
norm_sq0 = A%inner_product(B, B)
35-
residual0 = sqrt(norm_sq0)
36+
if(associated(workspace%callback)) call workspace%callback(x, norm_sq0, iter)
3637

3738
if(restart) X = zero_${s}$
3839
X = merge( B, X, di ) !> copy dirichlet load conditions encoded in B and indicated by di
3940

40-
call A%matvec(X, R)
41+
call A%apply(X, R)
4142
R = merge( zero_${s}$, B - R , di ) !> R = B - A*X
4243
norm_sq = A%inner_product(R, R)
4344

4445
P = R
4546

4647
tolsq = tol*tol
4748
beta = zero_${s}$
48-
iter = 0
49+
if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
4950
do while( norm_sq > tolsq * norm_sq0 .and. iter < maxiter)
50-
call A%matvec(P,Ap)
51+
call A%apply(P,Ap)
5152
Ap = merge( zero_${s}$, Ap, di )
5253

5354
alpha = norm_sq / A%inner_product(P, Ap)
@@ -62,8 +63,9 @@ contains
6263
P = R + beta * P
6364

6465
iter = iter + 1
66+
67+
if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
6568
end do
66-
residual = sqrt(norm_sq)
6769
end associate
6870
end subroutine
6971
#:endfor
@@ -98,7 +100,7 @@ contains
98100

99101
!-------------------------
100102
! internal memory setup
101-
op%matvec => default_matvec
103+
op%apply => default_matvec
102104
! op%inner_product => default_dot
103105
if(present(di))then
104106
di_ => di

src/stdlib_linalg_iterative_solvers_pccg.fypp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,35 @@ contains
2525
type(solver_workspace_${s}$), intent(inout) :: workspace
2626
!-------------------------
2727
integer :: iter
28-
${t}$ :: norm_sq, norm_sq0, norm_sq_old, residual0, residual
28+
${t}$ :: norm_sq, norm_sq0, norm_sq_old
2929
${t}$ :: zr1, zr2, zv2, alpha, beta, tolsq
3030
!-------------------------
31+
iter = 0
3132
associate( R => workspace%tmp(:,1), &
3233
S => workspace%tmp(:,2), &
3334
P => workspace%tmp(:,3), &
3435
Q => workspace%tmp(:,4))
3536
norm_sq = A%inner_product( b, b )
3637
norm_sq0 = norm_sq
37-
residual0 = sqrt(norm_sq0)
38+
if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
39+
3840
if ( norm_sq0 > zero_${s}$ ) then
3941
if(restart) X = zero_${s}$
4042
X = merge( B, X, di ) !> copy dirichlet load conditions encoded in B and indicated by di
4143

42-
call A%matvec(X, R)
44+
call A%apply(X, R)
4345
R = merge( zero_${s}$, B - R , di ) !> R = B - A*X
4446

45-
call M%matvec(R,P) !> P = M^{-1}*R
47+
call M%apply(R,P) !> P = M^{-1}*R
4648
P = merge( zero_${s}$, P, di )
4749

4850
tolsq = tol*tol
49-
iter = 0
51+
5052
zr1 = zero_${s}$
5153
zr2 = one_${s}$
5254
do while ( (iter < maxiter) .AND. (norm_sq > tolsq * norm_sq0) )
5355

54-
call M%matvec(R,S) !> S = M^{-1}*R
56+
call M%apply(R,S) !> S = M^{-1}*R
5557
S = merge( zero_${s}$, S, di )
5658
zr2 = A%inner_product( R, S )
5759

@@ -60,7 +62,7 @@ contains
6062
P = S + beta * P
6163
end if
6264

63-
call A%matvec(P, Q)
65+
call A%apply(P, Q)
6466
Q = merge( zero_${s}$, Q, di )
6567
zv2 = A%inner_product( P, Q )
6668

@@ -72,9 +74,9 @@ contains
7274
norm_sq_old = norm_sq
7375
zr1 = zr2
7476
iter = iter + 1
77+
if(associated(workspace%callback)) call workspace%callback(x, norm_sq, iter)
7578
end do
7679
end if
77-
residual = sqrt(norm_sq)
7880
end associate
7981
end subroutine
8082
#:endfor
@@ -111,12 +113,12 @@ contains
111113

112114
!-------------------------
113115
! internal memory setup
114-
op%matvec => default_matvec
116+
op%apply => default_matvec
115117
if(present(M)) then
116118
M_ => M
117119
else
118120
allocate( M_ )
119-
M_%matvec => default_preconditionner
121+
M_%apply => default_preconditionner
120122
end if
121123
if(present(di))then
122124
di_ => di

0 commit comments

Comments
 (0)