Skip to content

Commit 42aa3e5

Browse files
committed
add svd
1 parent 33559f5 commit 42aa3e5

File tree

2 files changed

+278
-1
lines changed

2 files changed

+278
-1
lines changed

src/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ set(fppFiles
2626
stdlib_linalg_kronecker.fypp
2727
stdlib_linalg_cross_product.fypp
2828
stdlib_linalg_determinant.fypp
29-
stdlib_linalg_state.fypp
29+
stdlib_linalg_state.fypp
30+
stdlib_linalg_svd.fypp
3031
stdlib_optval.fypp
3132
stdlib_selection.fypp
3233
stdlib_sorting.fypp

src/stdlib_linalg_svd.fypp

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
#:include "common.fypp"
2+
module stdlib_linalg_svd
3+
use stdlib_linalg_constants
4+
use stdlib_linalg_blas
5+
use stdlib_linalg_lapack
6+
use stdlib_linalg_state
7+
use iso_fortran_env,only:real32,real64,real128,int8,int16,int32,int64,stderr => error_unit
8+
implicit none(type,external)
9+
private
10+
11+
!> Singular value decomposition
12+
public :: svd
13+
!> Singular values
14+
public :: svdvals
15+
16+
! Numpy: svd(a, full_matrices=True, compute_uv=True, hermitian=False)
17+
! Scipy: svd(a, full_matrices=True, compute_uv=True, overwrite_a=False, check_finite=True, lapack_driver='gesdd')
18+
19+
interface svd
20+
#:for rk,rt,ri in ALL_KINDS_TYPES
21+
module procedure stdlib_linalg_svd_${ri}$
22+
#:endfor
23+
end interface svd
24+
25+
interface svdvals
26+
#:for rk,rt,ri in ALL_KINDS_TYPES
27+
module procedure stdlib_linalg_svdvals_${ri}$
28+
#:endfor
29+
end interface svdvals
30+
31+
!> Return full matrices U, V^T to separate storage
32+
character, parameter :: GESDD_FULL_MATRICES = 'A'
33+
34+
!> Return shrunk matrices U, V^T to k = min(m,n)
35+
character, parameter :: GESDD_SHRINK_MATRICES = 'S'
36+
37+
!> Overwrite A storage with U (if M>=N) or VT (if M<N); separate storage for the other matrix
38+
character, parameter :: GESDD_OVERWRITE_A = 'O'
39+
40+
!> Do not return either U or VT (singular values array only)
41+
character, parameter :: GESDD_SINGVAL_ONLY = 'N'
42+
43+
character(*), parameter :: this = 'svd'
44+
45+
46+
contains
47+
48+
!> Process GESDD output flag
49+
elemental subroutine gesdd_info(err,info,m,n)
50+
!> Error handler
51+
type(linalg_state), intent(inout) :: err
52+
!> GESDD return flag
53+
integer(ilp), intent(in) :: info
54+
!> Input matrix size
55+
integer(ilp), intent(in) :: m,n
56+
57+
select case (info)
58+
case (0)
59+
! Success!
60+
err%state = LINALG_SUCCESS
61+
case (-1)
62+
err = linalg_state(this,LINALG_INTERNAL_ERROR,'Invalid task ID on input to GESDD.')
63+
case (-5,-3:-2)
64+
err = linalg_state(this,LINALG_VALUE_ERROR,'invalid matrix size: a=[',m,',',n,']')
65+
case (-8)
66+
err = linalg_state(this,LINALG_VALUE_ERROR,'invalid matrix U size, with a=[',m,',',n,']')
67+
case (-10)
68+
err = linalg_state(this,LINALG_VALUE_ERROR,'invalid matrix V size, with a=[',m,',',n,']')
69+
case (-4)
70+
err = linalg_state(this,LINALG_VALUE_ERROR,'A contains invalid/NaN values.')
71+
case (1:)
72+
err = linalg_state(this,LINALG_ERROR,'SVD computation did not converge.')
73+
case default
74+
err = linalg_state(this,LINALG_INTERNAL_ERROR,'Unknown error returned by GESDD.')
75+
end select
76+
77+
end subroutine gesdd_info
78+
79+
80+
#:for rk,rt,ri in ALL_KINDS_TYPES
81+
82+
!> Singular values of matrix A
83+
function stdlib_linalg_svdvals_${ri}$(a,err) result(s)
84+
!> Input matrix A[m,n]
85+
${rt}$, intent(in), target :: a(:,:)
86+
!> [optional] state return flag. On error if not requested, the code will stop
87+
type(linalg_state), optional, intent(out) :: err
88+
!> Array of singular values
89+
real(${rk}$), allocatable :: s(:)
90+
91+
!> Create
92+
${rt}$, pointer :: amat(:,:)
93+
integer(ilp) :: m,n,k
94+
95+
!> Create an internal pointer so the intent of A won't affect the next call
96+
amat => a
97+
98+
m = size(a,1,kind=ilp)
99+
n = size(a,2,kind=ilp)
100+
k = min(m,n)
101+
102+
!> Allocate return storage
103+
allocate(s(k))
104+
105+
!> Compute singular values
106+
call stdlib_linalg_svd_${ri}$(amat,s,overwrite_a=.false.,err=err)
107+
108+
end function stdlib_linalg_svdvals_${ri}$
109+
110+
!> SVD of matrix A = U S V^T, returning S and optionally U and V^T
111+
subroutine stdlib_linalg_svd_${ri}$(a,s,u,vt,overwrite_a,full_matrices,err)
112+
!> Input matrix A[m,n]
113+
${rt}$, intent(inout), target :: a(:,:)
114+
!> Array of singular values
115+
real(${rk}$), intent(out) :: s(:)
116+
!> The columns of U contain the eigenvectors of A A^T
117+
${rt}$, optional, intent(out), target :: u(:,:)
118+
!> The rows of V^T contain the eigenvectors of A^T A
119+
${rt}$, optional, intent(out), target :: vt(:,:)
120+
!> [optional] Can A data be overwritten and destroyed?
121+
logical(lk), optional, intent(in) :: overwrite_a
122+
!> [optional] full matrices have shape(u)==[m,m], shape(vh)==[n,n] (default); otherwise
123+
!> they are shape(u)==[m,k] and shape(vh)==[k,n] with k=min(m,n)
124+
logical(lk), optional, intent(in) :: full_matrices
125+
!> [optional] state return flag. On error if not requested, the code will stop
126+
type(linalg_state), optional, intent(out) :: err
127+
128+
!> Local variables
129+
type(linalg_state) :: err0
130+
integer(ilp) :: m,n,lda,ldu,ldvt,info,k,lwork,liwork,lrwork
131+
integer(ilp), allocatable :: iwork(:)
132+
logical(lk) :: copy_a,full_storage,compute_uv,alloc_u,alloc_vt,can_overwrite_a
133+
character :: task
134+
${rt}$, target :: work_dummy(1),u_dummy(1,1),vt_dummy(1,1)
135+
${rt}$, allocatable :: work(:)
136+
real(${rk}$), allocatable :: rwork(:)
137+
${rt}$, pointer :: amat(:,:),umat(:,:),vtmat(:,:)
138+
139+
!> Matrix determinant size
140+
m = size(a,1,kind=ilp)
141+
n = size(a,2,kind=ilp)
142+
k = min(m,n)
143+
lda = m
144+
145+
if (.not.k>0) then
146+
err0 = linalg_state(this,LINALG_VALUE_ERROR,'invalid or matrix size: a=[',m,',',n,']')
147+
goto 1
148+
end if
149+
150+
if (.not.size(s,kind=ilp)>=k) then
151+
err0 = linalg_state(this,LINALG_VALUE_ERROR,'singular value array has insufficient size:',&
152+
' s=[',size(s,kind=ilp),'], k=',k)
153+
goto 1
154+
endif
155+
156+
! Integer storage
157+
liwork = 8*k
158+
allocate(iwork(liwork))
159+
160+
! Can A be overwritten? By default, do not overwrite
161+
if (present(overwrite_a)) then
162+
copy_a = .not.overwrite_a
163+
else
164+
copy_a = .true._lk
165+
endif
166+
167+
! Initialize a matrix temporary
168+
if (copy_a) then
169+
allocate(amat(m,n),source=a)
170+
171+
! Check if we can overwrite A with data that will be lost
172+
can_overwrite_a = merge(.not.present(u),.not.present(vt),m>=n)
173+
174+
else
175+
amat => a
176+
177+
can_overwrite_a = .false.
178+
179+
endif
180+
181+
! Full-size matrices
182+
if (present(full_matrices)) then
183+
full_storage = full_matrices
184+
else
185+
full_storage = .true.
186+
endif
187+
188+
! Decide if U, VT matrices should be computed
189+
compute_uv = present(u) .or. present(vt)
190+
191+
! U, VT storage
192+
if (present(u)) then
193+
umat => u
194+
alloc_u = .false.
195+
elseif ((copy_a .and. m>=n) .or. .not.compute_uv) then
196+
! U not wanted, and A can be overwritten: do not allocate
197+
umat => u_dummy
198+
alloc_u = .false.
199+
elseif (.not.full_storage) then
200+
allocate(umat(m,k))
201+
alloc_u = .true.
202+
else
203+
allocate(umat(m,m))
204+
alloc_u = .true.
205+
end if
206+
207+
if (present(vt)) then
208+
vtmat => vt
209+
alloc_vt = .false.
210+
elseif ((copy_a .and. m<n) .or. .not.compute_uv) then
211+
! amat can be overwritten, VT not wanted: VT is returned upon A
212+
vtmat => vt_dummy
213+
alloc_vt = .false.
214+
elseif (.not.full_storage) then
215+
allocate(vtmat(k,n))
216+
alloc_vt = .true.
217+
else
218+
allocate(vtmat(n,n))
219+
alloc_vt = .true.
220+
end if
221+
222+
ldu = size(umat ,1,kind=ilp)
223+
ldvt = size(vtmat,1,kind=ilp)
224+
225+
! Decide SVD task
226+
if (.not.compute_uv) then
227+
task = GESDD_SINGVAL_ONLY
228+
elseif (can_overwrite_a) then
229+
! A is a copy: we can overwrite its storage
230+
task = GESDD_OVERWRITE_A
231+
elseif (.not.full_storage) then
232+
task = GESDD_SHRINK_MATRICES
233+
else
234+
task = GESDD_FULL_MATRICES
235+
end if
236+
237+
! Compute workspace
238+
#:if rt.startswith('complex')
239+
if (task==GESDD_SINGVAL_ONLY) then
240+
lrwork = max(1,7*k)
241+
else
242+
lrwork = max(1,5*k*(k+1),2*k*(k+max(m,n))+k)
243+
endif
244+
allocate(rwork(lrwork))
245+
#:endif
246+
247+
lwork = -1_ilp
248+
249+
call gesdd(task,m,n,amat,lda,s,umat,ldu,vtmat,ldvt,&
250+
work_dummy,lwork,#{if rt.startswith('complex')}#rwork,#{endif}#iwork,info)
251+
call gesdd_info(err0,info,m,n)
252+
253+
! Compute SVD
254+
if (info==0) then
255+
256+
!> Prepare working storage
257+
lwork = nint(real(work_dummy(1),kind=${rk}$), kind=ilp)
258+
allocate(work(lwork))
259+
260+
!> Compute SVD
261+
call gesdd(task,m,n,amat,lda,s,umat,ldu,vtmat,ldvt,&
262+
work,lwork,#{if rt.startswith('comp')}#rwork,#{endif}#iwork,info)
263+
call gesdd_info(err0,info,m,n)
264+
265+
endif
266+
267+
! Finalize storage and process output flag
268+
if (copy_a) deallocate(amat)
269+
if (alloc_u) deallocate(umat)
270+
if (alloc_vt) deallocate(vtmat)
271+
1 call linalg_error_handling(err0,err)
272+
273+
end subroutine stdlib_linalg_svd_${ri}$
274+
#:endfor
275+
276+
end module stdlib_linalg_svd

0 commit comments

Comments
 (0)