Skip to content

Commit 036c574

Browse files
committed
add test programs
1 parent 64adda4 commit 036c574

File tree

2 files changed

+203
-0
lines changed

2 files changed

+203
-0
lines changed

test/linalg/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ set(
44
"test_blas_lapack.fypp"
55
"test_linalg_lstsq.fypp"
66
"test_linalg_determinant.fypp"
7+
"test_linalg_svd.fypp"
78
"test_linalg_matrix_property_checks.fypp"
89
)
910
fypp_f90("${fyppFlags}" "${fppFiles}" outFiles)
@@ -12,4 +13,5 @@ ADDTEST(linalg)
1213
ADDTEST(linalg_determinant)
1314
ADDTEST(linalg_matrix_property_checks)
1415
ADDTEST(linalg_lstsq)
16+
ADDTEST(linalg_svd)
1517
ADDTEST(blas_lapack)

test/linalg/test_linalg_svd.fypp

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#:include "common.fypp"
2+
#:set RC_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES
3+
! Test singular value decomposition
4+
module test_linalg_svd
5+
use testdrive, only: error_type, check, new_unittest, unittest_type
6+
use stdlib_linalg_constants
7+
use stdlib_linalg, only: diag
8+
use stdlib_linalg_svd, only: svd,svdvals
9+
use stdlib_linalg_state, only: linalg_state_type
10+
11+
implicit none (type,external)
12+
13+
contains
14+
15+
!> SVD tests
16+
subroutine test_svd(error)
17+
logical, intent(out) :: error
18+
19+
real :: t0,t1
20+
21+
call cpu_time(t0)
22+
23+
#:for rk,rt,ri in REAL_KINDS_TYPES
24+
#:if rk!="xdp"
25+
call test_svd_${ri}$(error)
26+
if (error) return
27+
#:endif
28+
#:endfor
29+
30+
#:for ck,ct,ci in CMPLX_KINDS_TYPES
31+
#:if ck!="xdp"
32+
call test_complex_svd_${ci}$(error)
33+
if (error) return
34+
#:endif
35+
#:endfor
36+
37+
call cpu_time(t1)
38+
39+
print 1, 1000*(t1-t0), merge('SUCCESS','ERROR ',.not.error)
40+
41+
1 format('SVD tests completed in ',f9.4,' milliseconds, result=',a)
42+
43+
end subroutine test_svd
44+
45+
!> Real matrix svd
46+
#:for rk,rt,ri in REAL_KINDS_TYPES
47+
#:if rk!="xdp"
48+
subroutine test_svd_${ri}$(error)
49+
logical,intent(out) :: error
50+
51+
!> Reference solution
52+
${rt}$, parameter :: tol = sqrt(epsilon(0.0_${rk}$))
53+
${rt}$, parameter :: third = 1.0_${rk}$/3.0_${rk}$
54+
${rt}$, parameter :: twothd = 2*third
55+
${rt}$, parameter :: rsqrt2 = 1.0_${rk}$/sqrt(2.0_${rk}$)
56+
${rt}$, parameter :: rsqrt18 = 1.0_${rk}$/sqrt(18.0_${rk}$)
57+
58+
${rt}$, parameter :: A_mat(2,3) = reshape([${rt}$ :: 3,2, 2,3, 2,-2],[2,3])
59+
${rt}$, parameter :: s_sol(2) = [${rt}$ :: 5, 3]
60+
${rt}$, parameter :: u_sol(2,2) = reshape(rsqrt2*[1,1,1,-1],[2,2])
61+
${rt}$, parameter :: vt_sol(3,3) = reshape([rsqrt2,rsqrt18,twothd, &
62+
rsqrt2,-rsqrt18,-twothd,&
63+
0.0_${rk}$,4*rsqrt18,-third],[3,3])
64+
65+
!> Local variables
66+
type(linalg_state_type) :: state
67+
${rt}$ :: A(2,3),s(2),u(2,2),vt(3,3)
68+
69+
!> Initialize matrix
70+
A = A_mat
71+
72+
!> Simple subroutine version
73+
call svd(A,s,err=state)
74+
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
75+
if (error) return
76+
77+
!> Function interface
78+
s = svdvals(A,err=state)
79+
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
80+
if (error) return
81+
82+
!> [S, U]. Singular vectors could be all flipped
83+
call svd(A,s,u,err=state)
84+
error = state%error() .or. &
85+
.not. all(abs(s-s_sol)<=tol) .or. &
86+
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
87+
if (error) return
88+
89+
!> [S, U]. Overwrite A matrix
90+
call svd(A,s,u,overwrite_a=.true.,err=state)
91+
error = state%error() .or. &
92+
.not. all(abs(s-s_sol)<=tol) .or. &
93+
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
94+
if (error) return
95+
96+
!> [S, U, V^T]
97+
A = A_mat
98+
call svd(A,s,u,vt,overwrite_a=.true.,err=state)
99+
error = state%error() .or. &
100+
.not. all(abs(s-s_sol)<=tol) .or. &
101+
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
102+
.not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
103+
if (error) return
104+
105+
!> [S, V^T]. Do not overwrite A matrix
106+
A = A_mat
107+
call svd(A,s,vt=vt,err=state)
108+
error = state%error() .or. &
109+
.not. all(abs(s-s_sol)<=tol) .or. &
110+
.not.(all(abs(vt+vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
111+
if (error) return
112+
113+
!> [S, V^T]. Overwrite A matrix
114+
call svd(A,s,vt=vt,overwrite_a=.true.,err=state)
115+
error = state%error() .or. &
116+
.not. all(abs(s-s_sol)<=tol) .or. &
117+
.not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
118+
if (error) return
119+
120+
!> [U, S, V^T].
121+
A = A_mat
122+
call svd(A,s,u,vt,err=state)
123+
error = state%error() .or. &
124+
.not. all(abs(s-s_sol)<=tol) .or. &
125+
.not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
126+
.not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
127+
if (error) return
128+
129+
!> [U, S, V^T]. Partial storage -> compare until k=2 columns of U rows of V^T
130+
A = A_mat
131+
u = 0
132+
vt = 0
133+
call svd(A,s,u,vt,full_matrices=.false.,err=state)
134+
error = state%error() &
135+
.or. .not. all(abs(s-s_sol)<=tol) &
136+
.or. .not.(all(abs( u(:,:2)- u_sol(:,:2))<=tol) .or. all(abs( u(:,:2)+ u_sol(:,:2))<=tol)) &
137+
.or. .not.(all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol))
138+
139+
if (error) return
140+
141+
end subroutine test_svd_${ri}$
142+
143+
#:endif
144+
#:endfor
145+
146+
!> Test complex svd
147+
#:for ck,ct,ci in CMPLX_KINDS_TYPES
148+
#:if ck!="xdp"
149+
subroutine test_complex_svd_${ci}$(error)
150+
logical,intent(out) :: error
151+
152+
!> Reference solution
153+
real(${ck}$), parameter :: tol = sqrt(epsilon(0.0_${ck}$))
154+
real(${ck}$), parameter :: one = 1.0_${ck}$
155+
real(${ck}$), parameter :: zero = 0.0_${ck}$
156+
real(${ck}$), parameter :: sqrt2 = sqrt(2.0_${ck}$)
157+
real(${ck}$), parameter :: rsqrt2 = one/sqrt2
158+
${ct}$, parameter :: cone = (1.0_${ck}$,0.0_${ck}$)
159+
${ct}$, parameter :: cimg = (0.0_${ck}$,1.0_${ck}$)
160+
${ct}$, parameter :: czero = (0.0_${ck}$,0.0_${ck}$)
161+
162+
real(${ck}$), parameter :: s_sol(2) = [sqrt2,sqrt2]
163+
${ct}$, parameter :: A_mat(2,2) = reshape([cone,cimg,cimg,cone],[2,2])
164+
${ct}$, parameter :: u_sol(2,2) = reshape(rsqrt2*[cone,cimg,cimg,cone],[2,2])
165+
${ct}$, parameter :: vt_sol(2,2) = reshape([cone,czero,czero,cone],[2,2])
166+
167+
!> Local variables
168+
type(linalg_state_type) :: state
169+
${ct}$ :: A(2,2),u(2,2),vt(2,2)
170+
real(${ck}$) :: s(2)
171+
172+
!> Initialize matrix
173+
A = A_mat
174+
175+
!> Simple subroutine version
176+
call svd(A,s,err=state)
177+
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
178+
if (error) return
179+
180+
!> Function interface
181+
s = svdvals(A,err=state)
182+
error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
183+
if (error) return
184+
185+
!> [S, U, V^T]
186+
A = A_mat
187+
call svd(A,s,u,vt,overwrite_a=.true.,err=state)
188+
error = state%error() .or. &
189+
.not. all(abs(s-s_sol)<=tol) .or. &
190+
.not. all(abs(matmul(u,matmul(diag(s),vt)) - A_mat)<=tol)
191+
if (error) return
192+
193+
end subroutine test_complex_svd_${ci}$
194+
195+
#:endif
196+
#:endfor
197+
198+
199+
end module test_linalg_svd
200+
201+

0 commit comments

Comments
 (0)