Skip to content

Commit 4a90c09

Browse files
committed
tests: replace with testdrive
1 parent 036c574 commit 4a90c09

File tree

1 file changed

+137
-69
lines changed

1 file changed

+137
-69
lines changed

test/linalg/test_linalg_svd.fypp

Lines changed: 137 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -9,44 +9,37 @@ module test_linalg_svd
99
use stdlib_linalg_state, only: linalg_state_type
1010

1111
implicit none (type,external)
12+
13+
public :: test_svd
1214

1315
contains
1416

15-
!> SVD tests
16-
subroutine test_svd(error)
17-
logical, intent(out) :: error
18-
19-
real :: t0,t1
20-
21-
call cpu_time(t0)
17+
!> Solve several SVD problems
18+
subroutine test_svd(tests)
19+
!> Collection of tests
20+
type(unittest_type), allocatable, intent(out) :: tests(:)
21+
22+
allocate(tests(0))
2223

2324
#:for rk,rt,ri in REAL_KINDS_TYPES
2425
#:if rk!="xdp"
25-
call test_svd_${ri}$(error)
26-
if (error) return
26+
tests = [tests,new_unittest("test_svd_${ri}$",test_svd_${ri}$)]
2727
#:endif
2828
#:endfor
2929

3030
#:for ck,ct,ci in CMPLX_KINDS_TYPES
3131
#:if ck!="xdp"
32-
call test_complex_svd_${ci}$(error)
33-
if (error) return
32+
tests = [tests,new_unittest("test_complex_svd_${ci}$",test_complex_svd_${ci}$)]
3433
#:endif
3534
#:endfor
3635

37-
call cpu_time(t1)
38-
39-
print 1, 1000*(t1-t0), merge('SUCCESS','ERROR ',.not.error)
40-
41-
1 format('SVD tests completed in ',f9.4,' milliseconds, result=',a)
42-
4336
end subroutine test_svd
4437

4538
!> Real matrix svd
4639
#:for rk,rt,ri in REAL_KINDS_TYPES
4740
#:if rk!="xdp"
4841
subroutine test_svd_${ri}$(error)
49-
logical,intent(out) :: error
42+
type(error_type), allocatable, intent(out) :: error
5043

5144
!> Reference solution
5245
${rt}$, parameter :: tol = sqrt(epsilon(0.0_${rk}$))
@@ -63,6 +56,7 @@ module test_linalg_svd
6356
0.0_${rk}$,4*rsqrt18,-third],[3,3])
6457

6558
!> Local variables
59+
character(:), allocatable :: test
6660
type(linalg_state_type) :: state
6761
${rt}$ :: A(2,3),s(2),u(2,2),vt(3,3)
6862

@@ -71,72 +65,110 @@ module test_linalg_svd
7165

7266
!> Simple subroutine version
7367
call svd(A,s,err=state)
74-
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
75-
if (error) return
76-
68+
69+
test = 'subroutine version'
70+
call check(error,state%ok(),test//': '//state%print())
71+
if (allocated(error)) return
72+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
73+
if (allocated(error)) return
74+
7775
!> Function interface
7876
s = svdvals(A,err=state)
79-
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
80-
if (error) return
77+
78+
test = 'function interface'
79+
call check(error,state%ok(),test//': '//state%print())
80+
if (allocated(error)) return
81+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
82+
if (allocated(error)) return
8183

8284
!> [S, U]. Singular vectors could be all flipped
8385
call svd(A,s,u,err=state)
84-
error = state%error() .or. &
85-
.not. all(abs(s-s_sol)<=tol) .or. &
86-
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
87-
if (error) return
86+
87+
test = 'subroutine with singular vectors'
88+
call check(error,state%ok(),test//': '//state%print())
89+
if (allocated(error)) return
90+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
91+
if (allocated(error)) return
92+
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
93+
if (allocated(error)) return
8894

8995
!> [S, U]. Overwrite A matrix
9096
call svd(A,s,u,overwrite_a=.true.,err=state)
91-
error = state%error() .or. &
92-
.not. all(abs(s-s_sol)<=tol) .or. &
93-
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
94-
if (error) return
97+
98+
test = 'subroutine, overwrite_a'
99+
call check(error,state%ok(),test//': '//state%print())
100+
if (allocated(error)) return
101+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
102+
if (allocated(error)) return
103+
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
104+
if (allocated(error)) return
95105

96106
!> [S, U, V^T]
97107
A = A_mat
98108
call svd(A,s,u,vt,overwrite_a=.true.,err=state)
99-
error = state%error() .or. &
100-
.not. all(abs(s-s_sol)<=tol) .or. &
101-
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
102-
.not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
103-
if (error) return
104-
109+
110+
test = '[S, U, V^T]'
111+
call check(error,state%ok(),test//': '//state%print())
112+
if (allocated(error)) return
113+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
114+
if (allocated(error)) return
115+
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
116+
if (allocated(error)) return
117+
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
118+
if (allocated(error)) return
119+
105120
!> [S, V^T]. Do not overwrite A matrix
106121
A = A_mat
107122
call svd(A,s,vt=vt,err=state)
108-
error = state%error() .or. &
109-
.not. all(abs(s-s_sol)<=tol) .or. &
110-
.not.(all(abs(vt+vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
111-
if (error) return
123+
124+
test = '[S, V^T], overwrite_a=.false.'
125+
call check(error,state%ok(),test//': '//state%print())
126+
if (allocated(error)) return
127+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
128+
if (allocated(error)) return
129+
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
130+
if (allocated(error)) return
112131

113132
!> [S, V^T]. Overwrite A matrix
114133
call svd(A,s,vt=vt,overwrite_a=.true.,err=state)
115-
error = state%error() .or. &
116-
.not. all(abs(s-s_sol)<=tol) .or. &
117-
.not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
118-
if (error) return
119-
134+
135+
test = '[S, V^T], overwrite_a=.true.'
136+
call check(error,state%ok(),test//': '//state%print())
137+
if (allocated(error)) return
138+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
139+
if (allocated(error)) return
140+
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
141+
if (allocated(error)) return
142+
120143
!> [U, S, V^T].
121144
A = A_mat
122145
call svd(A,s,u,vt,err=state)
123-
error = state%error() .or. &
124-
.not. all(abs(s-s_sol)<=tol) .or. &
125-
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
126-
.not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
127-
if (error) return
146+
147+
test = '[U, S, V^T]'
148+
call check(error,state%ok(),test//': '//state%print())
149+
if (allocated(error)) return
150+
call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
151+
if (allocated(error)) return
152+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
153+
if (allocated(error)) return
154+
call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
155+
if (allocated(error)) return
128156

129157
!> [U, S, V^T]. Partial storage -> compare until k=2 columns of U rows of V^T
130158
A = A_mat
131159
u = 0
132160
vt = 0
133161
call svd(A,s,u,vt,full_matrices=.false.,err=state)
134-
error = state%error() &
135-
.or. .not. all(abs(s-s_sol)<=tol) &
136-
.or. .not.(all(abs( u(:,:2)- u_sol(:,:2))<=tol) .or. all(abs( u(:,:2)+ u_sol(:,:2))<=tol)) &
137-
.or. .not.(all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol))
138-
139-
if (error) return
162+
163+
test = '[U, S, V^T], partial storage'
164+
call check(error,state%ok(),test//': '//state%print())
165+
if (allocated(error)) return
166+
call check(error, all(abs(u(:,:2)-u_sol(:,:2))<=tol) .or. all(abs(u(:,:2)+u_sol(:,:2))<=tol), test//': U(:,:2)')
167+
if (allocated(error)) return
168+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
169+
if (allocated(error)) return
170+
call check(error, all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol), test//': V^T(:2,:)')
171+
if (allocated(error)) return
140172

141173
end subroutine test_svd_${ri}$
142174

@@ -147,7 +179,7 @@ module test_linalg_svd
147179
#:for ck,ct,ci in CMPLX_KINDS_TYPES
148180
#:if ck!="xdp"
149181
subroutine test_complex_svd_${ci}$(error)
150-
logical,intent(out) :: error
182+
type(error_type), allocatable, intent(out) :: error
151183

152184
!> Reference solution
153185
real(${ck}$), parameter :: tol = sqrt(epsilon(0.0_${ck}$))
@@ -165,6 +197,7 @@ module test_linalg_svd
165197
${ct}$, parameter :: vt_sol(2,2) = reshape([cone,czero,czero,cone],[2,2])
166198

167199
!> Local variables
200+
character(:), allocatable :: test
168201
type(linalg_state_type) :: state
169202
${ct}$ :: A(2,2),u(2,2),vt(2,2)
170203
real(${ck}$) :: s(2)
@@ -174,28 +207,63 @@ module test_linalg_svd
174207

175208
!> Simple subroutine version
176209
call svd(A,s,err=state)
177-
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
178-
if (error) return
179-
210+
211+
test = '[S], complex subroutine'
212+
call check(error,state%ok(),test//': '//state%print())
213+
if (allocated(error)) return
214+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
215+
if (allocated(error)) return
216+
180217
!> Function interface
181218
s = svdvals(A,err=state)
182-
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
183-
if (error) return
219+
220+
test = 'svdvals, complex function'
221+
call check(error,state%ok(),test//': '//state%print())
222+
if (allocated(error)) return
223+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
224+
if (allocated(error)) return
184225

185226
!> [S, U, V^T]
186227
A = A_mat
187228
call svd(A,s,u,vt,overwrite_a=.true.,err=state)
188-
error = state%error() .or. &
189-
.not. all(abs(s-s_sol)<=tol) .or. &
190-
.not. all(abs(matmul(u,matmul(diag(s),vt)) - A_mat)<=tol)
191-
if (error) return
229+
230+
test = '[S, U, V^T], complex'
231+
call check(error,state%ok(),test//': '//state%print())
232+
if (allocated(error)) return
233+
call check(error, all(abs(s-s_sol)<=tol), test//': S')
234+
if (allocated(error)) return
235+
call check(error, all(abs(matmul(u,matmul(diag(s),vt))-A_mat)<=tol), test//': U*S*V^T')
236+
if (allocated(error)) return
192237

193238
end subroutine test_complex_svd_${ci}$
194239

195240
#:endif
196241
#:endfor
197242

198-
199243
end module test_linalg_svd
200244

201-
245+
program test_lstsq
246+
use, intrinsic :: iso_fortran_env, only : error_unit
247+
use testdrive, only : run_testsuite, new_testsuite, testsuite_type
248+
use test_linalg_svd, only : test_svd
249+
implicit none
250+
integer :: stat, is
251+
type(testsuite_type), allocatable :: testsuites(:)
252+
character(len=*), parameter :: fmt = '("#", *(1x, a))'
253+
254+
stat = 0
255+
256+
testsuites = [ &
257+
new_testsuite("linalg_svd", test_svd) &
258+
]
259+
260+
do is = 1, size(testsuites)
261+
write(error_unit, fmt) "Testing:", testsuites(is)%name
262+
call run_testsuite(testsuites(is)%collect, error_unit, stat)
263+
end do
264+
265+
if (stat > 0) then
266+
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
267+
error stop
268+
end if
269+
end program test_lstsq

0 commit comments

Comments
 (0)