@@ -9,44 +9,37 @@ module test_linalg_svd
9
9
use stdlib_linalg_state, only: linalg_state_type
10
10
11
11
implicit none (type,external)
12
+
13
+ public :: test_svd
12
14
13
15
contains
14
16
15
- !> SVD tests
16
- subroutine test_svd(error)
17
- logical, intent(out) :: error
18
-
19
- real :: t0,t1
20
-
21
- call cpu_time(t0)
17
+ !> Solve several SVD problems
18
+ subroutine test_svd(tests)
19
+ !> Collection of tests
20
+ type(unittest_type), allocatable, intent(out) :: tests(:)
21
+
22
+ allocate(tests(0))
22
23
23
24
#:for rk,rt,ri in REAL_KINDS_TYPES
24
25
#:if rk!="xdp"
25
- call test_svd_${ri}$(error)
26
- if (error) return
26
+ tests = [tests,new_unittest("test_svd_${ri}$",test_svd_${ri}$)]
27
27
#:endif
28
28
#:endfor
29
29
30
30
#:for ck,ct,ci in CMPLX_KINDS_TYPES
31
31
#:if ck!="xdp"
32
- call test_complex_svd_${ci}$(error)
33
- if (error) return
32
+ tests = [tests,new_unittest("test_complex_svd_${ci}$",test_complex_svd_${ci}$)]
34
33
#:endif
35
34
#:endfor
36
35
37
- call cpu_time(t1)
38
-
39
- print 1, 1000*(t1-t0), merge('SUCCESS','ERROR ',.not.error)
40
-
41
- 1 format('SVD tests completed in ',f9.4,' milliseconds, result=',a)
42
-
43
36
end subroutine test_svd
44
37
45
38
!> Real matrix svd
46
39
#:for rk,rt,ri in REAL_KINDS_TYPES
47
40
#:if rk!="xdp"
48
41
subroutine test_svd_${ri}$(error)
49
- logical, intent(out) :: error
42
+ type(error_type), allocatable, intent(out) :: error
50
43
51
44
!> Reference solution
52
45
${rt}$, parameter :: tol = sqrt(epsilon(0.0_${rk}$))
@@ -63,6 +56,7 @@ module test_linalg_svd
63
56
0.0_${rk}$,4*rsqrt18,-third],[3,3])
64
57
65
58
!> Local variables
59
+ character(:), allocatable :: test
66
60
type(linalg_state_type) :: state
67
61
${rt}$ :: A(2,3),s(2),u(2,2),vt(3,3)
68
62
@@ -71,72 +65,110 @@ module test_linalg_svd
71
65
72
66
!> Simple subroutine version
73
67
call svd(A,s,err=state)
74
- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
75
- if (error) return
76
-
68
+
69
+ test = 'subroutine version'
70
+ call check(error,state%ok(),test//': '//state%print())
71
+ if (allocated(error)) return
72
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
73
+ if (allocated(error)) return
74
+
77
75
!> Function interface
78
76
s = svdvals(A,err=state)
79
- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
80
- if (error) return
77
+
78
+ test = 'function interface'
79
+ call check(error,state%ok(),test//': '//state%print())
80
+ if (allocated(error)) return
81
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
82
+ if (allocated(error)) return
81
83
82
84
!> [S, U]. Singular vectors could be all flipped
83
85
call svd(A,s,u,err=state)
84
- error = state%error() .or. &
85
- .not. all(abs(s-s_sol)<=tol) .or. &
86
- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
87
- if (error) return
86
+
87
+ test = 'subroutine with singular vectors'
88
+ call check(error,state%ok(),test//': '//state%print())
89
+ if (allocated(error)) return
90
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
91
+ if (allocated(error)) return
92
+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
93
+ if (allocated(error)) return
88
94
89
95
!> [S, U]. Overwrite A matrix
90
96
call svd(A,s,u,overwrite_a=.true.,err=state)
91
- error = state%error() .or. &
92
- .not. all(abs(s-s_sol)<=tol) .or. &
93
- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
94
- if (error) return
97
+
98
+ test = 'subroutine, overwrite_a'
99
+ call check(error,state%ok(),test//': '//state%print())
100
+ if (allocated(error)) return
101
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
102
+ if (allocated(error)) return
103
+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
104
+ if (allocated(error)) return
95
105
96
106
!> [S, U, V^T]
97
107
A = A_mat
98
108
call svd(A,s,u,vt,overwrite_a=.true.,err=state)
99
- error = state%error() .or. &
100
- .not. all(abs(s-s_sol)<=tol) .or. &
101
- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
102
- .not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
103
- if (error) return
104
-
109
+
110
+ test = '[S, U, V^T]'
111
+ call check(error,state%ok(),test//': '//state%print())
112
+ if (allocated(error)) return
113
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
114
+ if (allocated(error)) return
115
+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
116
+ if (allocated(error)) return
117
+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
118
+ if (allocated(error)) return
119
+
105
120
!> [S, V^T]. Do not overwrite A matrix
106
121
A = A_mat
107
122
call svd(A,s,vt=vt,err=state)
108
- error = state%error() .or. &
109
- .not. all(abs(s-s_sol)<=tol) .or. &
110
- .not.(all(abs(vt+vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
111
- if (error) return
123
+
124
+ test = '[S, V^T], overwrite_a=.false.'
125
+ call check(error,state%ok(),test//': '//state%print())
126
+ if (allocated(error)) return
127
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
128
+ if (allocated(error)) return
129
+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
130
+ if (allocated(error)) return
112
131
113
132
!> [S, V^T]. Overwrite A matrix
114
133
call svd(A,s,vt=vt,overwrite_a=.true.,err=state)
115
- error = state%error() .or. &
116
- .not. all(abs(s-s_sol)<=tol) .or. &
117
- .not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
118
- if (error) return
119
-
134
+
135
+ test = '[S, V^T], overwrite_a=.true.'
136
+ call check(error,state%ok(),test//': '//state%print())
137
+ if (allocated(error)) return
138
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
139
+ if (allocated(error)) return
140
+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
141
+ if (allocated(error)) return
142
+
120
143
!> [U, S, V^T].
121
144
A = A_mat
122
145
call svd(A,s,u,vt,err=state)
123
- error = state%error() .or. &
124
- .not. all(abs(s-s_sol)<=tol) .or. &
125
- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
126
- .not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
127
- if (error) return
146
+
147
+ test = '[U, S, V^T]'
148
+ call check(error,state%ok(),test//': '//state%print())
149
+ if (allocated(error)) return
150
+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
151
+ if (allocated(error)) return
152
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
153
+ if (allocated(error)) return
154
+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
155
+ if (allocated(error)) return
128
156
129
157
!> [U, S, V^T]. Partial storage -> compare until k=2 columns of U rows of V^T
130
158
A = A_mat
131
159
u = 0
132
160
vt = 0
133
161
call svd(A,s,u,vt,full_matrices=.false.,err=state)
134
- error = state%error() &
135
- .or. .not. all(abs(s-s_sol)<=tol) &
136
- .or. .not.(all(abs( u(:,:2)- u_sol(:,:2))<=tol) .or. all(abs( u(:,:2)+ u_sol(:,:2))<=tol)) &
137
- .or. .not.(all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol))
138
-
139
- if (error) return
162
+
163
+ test = '[U, S, V^T], partial storage'
164
+ call check(error,state%ok(),test//': '//state%print())
165
+ if (allocated(error)) return
166
+ call check(error, all(abs(u(:,:2)-u_sol(:,:2))<=tol) .or. all(abs(u(:,:2)+u_sol(:,:2))<=tol), test//': U(:,:2)')
167
+ if (allocated(error)) return
168
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
169
+ if (allocated(error)) return
170
+ call check(error, all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol), test//': V^T(:2,:)')
171
+ if (allocated(error)) return
140
172
141
173
end subroutine test_svd_${ri}$
142
174
@@ -147,7 +179,7 @@ module test_linalg_svd
147
179
#:for ck,ct,ci in CMPLX_KINDS_TYPES
148
180
#:if ck!="xdp"
149
181
subroutine test_complex_svd_${ci}$(error)
150
- logical, intent(out) :: error
182
+ type(error_type), allocatable, intent(out) :: error
151
183
152
184
!> Reference solution
153
185
real(${ck}$), parameter :: tol = sqrt(epsilon(0.0_${ck}$))
@@ -165,6 +197,7 @@ module test_linalg_svd
165
197
${ct}$, parameter :: vt_sol(2,2) = reshape([cone,czero,czero,cone],[2,2])
166
198
167
199
!> Local variables
200
+ character(:), allocatable :: test
168
201
type(linalg_state_type) :: state
169
202
${ct}$ :: A(2,2),u(2,2),vt(2,2)
170
203
real(${ck}$) :: s(2)
@@ -174,28 +207,63 @@ module test_linalg_svd
174
207
175
208
!> Simple subroutine version
176
209
call svd(A,s,err=state)
177
- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
178
- if (error) return
179
-
210
+
211
+ test = '[S], complex subroutine'
212
+ call check(error,state%ok(),test//': '//state%print())
213
+ if (allocated(error)) return
214
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
215
+ if (allocated(error)) return
216
+
180
217
!> Function interface
181
218
s = svdvals(A,err=state)
182
- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
183
- if (error) return
219
+
220
+ test = 'svdvals, complex function'
221
+ call check(error,state%ok(),test//': '//state%print())
222
+ if (allocated(error)) return
223
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
224
+ if (allocated(error)) return
184
225
185
226
!> [S, U, V^T]
186
227
A = A_mat
187
228
call svd(A,s,u,vt,overwrite_a=.true.,err=state)
188
- error = state%error() .or. &
189
- .not. all(abs(s-s_sol)<=tol) .or. &
190
- .not. all(abs(matmul(u,matmul(diag(s),vt)) - A_mat)<=tol)
191
- if (error) return
229
+
230
+ test = '[S, U, V^T], complex'
231
+ call check(error,state%ok(),test//': '//state%print())
232
+ if (allocated(error)) return
233
+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
234
+ if (allocated(error)) return
235
+ call check(error, all(abs(matmul(u,matmul(diag(s),vt))-A_mat)<=tol), test//': U*S*V^T')
236
+ if (allocated(error)) return
192
237
193
238
end subroutine test_complex_svd_${ci}$
194
239
195
240
#:endif
196
241
#:endfor
197
242
198
-
199
243
end module test_linalg_svd
200
244
201
-
245
+ program test_lstsq
246
+ use, intrinsic :: iso_fortran_env, only : error_unit
247
+ use testdrive, only : run_testsuite, new_testsuite, testsuite_type
248
+ use test_linalg_svd, only : test_svd
249
+ implicit none
250
+ integer :: stat, is
251
+ type(testsuite_type), allocatable :: testsuites(:)
252
+ character(len=*), parameter :: fmt = '("#", *(1x, a))'
253
+
254
+ stat = 0
255
+
256
+ testsuites = [ &
257
+ new_testsuite("linalg_svd", test_svd) &
258
+ ]
259
+
260
+ do is = 1, size(testsuites)
261
+ write(error_unit, fmt) "Testing:", testsuites(is)%name
262
+ call run_testsuite(testsuites(is)%collect, error_unit, stat)
263
+ end do
264
+
265
+ if (stat > 0) then
266
+ write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
267
+ error stop
268
+ end if
269
+ end program test_lstsq
0 commit comments