@@ -30,6 +30,307 @@ pub trait SVD_: Scalar {
30
30
-> Result < SVDOutput < Self > > ;
31
31
}
32
32
33
+ pub struct SvdWork < T : Scalar > {
34
+ pub ju : JobSvd ,
35
+ pub jvt : JobSvd ,
36
+ pub layout : MatrixLayout ,
37
+ pub s : Vec < MaybeUninit < T :: Real > > ,
38
+ pub u : Option < Vec < MaybeUninit < T > > > ,
39
+ pub vt : Option < Vec < MaybeUninit < T > > > ,
40
+ pub work : Vec < MaybeUninit < T > > ,
41
+ pub rwork : Option < Vec < MaybeUninit < T :: Real > > > ,
42
+ }
43
+
44
+ #[ derive( Debug , Clone ) ]
45
+ pub struct SvdRef < ' work , T : Scalar > {
46
+ pub s : & ' work [ T :: Real ] ,
47
+ pub u : Option < & ' work [ T ] > ,
48
+ pub vt : Option < & ' work [ T ] > ,
49
+ }
50
+
51
+ #[ derive( Debug , Clone ) ]
52
+ pub struct SvdOwned < T : Scalar > {
53
+ pub s : Vec < T :: Real > ,
54
+ pub u : Option < Vec < T > > ,
55
+ pub vt : Option < Vec < T > > ,
56
+ }
57
+
58
+ pub trait SvdWorkImpl : Sized {
59
+ type Elem : Scalar ;
60
+ fn new ( layout : MatrixLayout , calc_u : bool , calc_vt : bool ) -> Result < Self > ;
61
+ fn calc ( & mut self , a : & mut [ Self :: Elem ] ) -> Result < SvdRef < Self :: Elem > > ;
62
+ fn eval ( self , a : & mut [ Self :: Elem ] ) -> Result < SvdOwned < Self :: Elem > > ;
63
+ }
64
+
65
+ macro_rules! impl_svd_work_c {
66
+ ( $s: ty, $svd: path) => {
67
+ impl SvdWorkImpl for SvdWork <$s> {
68
+ type Elem = $s;
69
+
70
+ fn new( layout: MatrixLayout , calc_u: bool , calc_vt: bool ) -> Result <Self > {
71
+ let ju = match layout {
72
+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_u) ,
73
+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_vt) ,
74
+ } ;
75
+ let jvt = match layout {
76
+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_vt) ,
77
+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_u) ,
78
+ } ;
79
+
80
+ let m = layout. lda( ) ;
81
+ let mut u = match ju {
82
+ JobSvd :: All => Some ( vec_uninit( ( m * m) as usize ) ) ,
83
+ JobSvd :: None => None ,
84
+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
85
+ } ;
86
+
87
+ let n = layout. len( ) ;
88
+ let mut vt = match jvt {
89
+ JobSvd :: All => Some ( vec_uninit( ( n * n) as usize ) ) ,
90
+ JobSvd :: None => None ,
91
+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
92
+ } ;
93
+
94
+ let k = std:: cmp:: min( m, n) ;
95
+ let mut s = vec_uninit( k as usize ) ;
96
+ let mut rwork = vec_uninit( 5 * k as usize ) ;
97
+
98
+ // eval work size
99
+ let mut info = 0 ;
100
+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
101
+ unsafe {
102
+ $svd(
103
+ ju. as_ptr( ) ,
104
+ jvt. as_ptr( ) ,
105
+ & m,
106
+ & n,
107
+ std:: ptr:: null_mut( ) ,
108
+ & m,
109
+ AsPtr :: as_mut_ptr( & mut s) ,
110
+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
111
+ & m,
112
+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
113
+ & n,
114
+ AsPtr :: as_mut_ptr( & mut work_size) ,
115
+ & ( -1 ) ,
116
+ AsPtr :: as_mut_ptr( & mut rwork) ,
117
+ & mut info,
118
+ ) ;
119
+ }
120
+ info. as_lapack_result( ) ?;
121
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
122
+ let work = vec_uninit( lwork) ;
123
+ Ok ( SvdWork {
124
+ layout,
125
+ ju,
126
+ jvt,
127
+ s,
128
+ u,
129
+ vt,
130
+ work,
131
+ rwork: Some ( rwork) ,
132
+ } )
133
+ }
134
+
135
+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
136
+ let m = self . layout. lda( ) ;
137
+ let n = self . layout. len( ) ;
138
+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
139
+
140
+ let mut info = 0 ;
141
+ unsafe {
142
+ $svd(
143
+ self . ju. as_ptr( ) ,
144
+ self . jvt. as_ptr( ) ,
145
+ & m,
146
+ & n,
147
+ AsPtr :: as_mut_ptr( a) ,
148
+ & m,
149
+ AsPtr :: as_mut_ptr( & mut self . s) ,
150
+ AsPtr :: as_mut_ptr(
151
+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
152
+ ) ,
153
+ & m,
154
+ AsPtr :: as_mut_ptr(
155
+ self . vt
156
+ . as_mut( )
157
+ . map( |x| x. as_mut_slice( ) )
158
+ . unwrap_or( & mut [ ] ) ,
159
+ ) ,
160
+ & n,
161
+ AsPtr :: as_mut_ptr( & mut self . work) ,
162
+ & ( lwork as i32 ) ,
163
+ AsPtr :: as_mut_ptr( self . rwork. as_mut( ) . unwrap( ) ) ,
164
+ & mut info,
165
+ ) ;
166
+ }
167
+ info. as_lapack_result( ) ?;
168
+
169
+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
170
+ let u = self
171
+ . u
172
+ . as_ref( )
173
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
174
+ let vt = self
175
+ . vt
176
+ . as_ref( )
177
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
178
+
179
+ match self . layout {
180
+ MatrixLayout :: F { .. } => Ok ( SvdRef { s, u, vt } ) ,
181
+ MatrixLayout :: C { .. } => Ok ( SvdRef { s, u: vt, vt: u } ) ,
182
+ }
183
+ }
184
+
185
+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
186
+ let _ref = self . calc( a) ?;
187
+ let s = unsafe { self . s. assume_init( ) } ;
188
+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
189
+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
190
+ match self . layout {
191
+ MatrixLayout :: F { .. } => Ok ( SvdOwned { s, u, vt } ) ,
192
+ MatrixLayout :: C { .. } => Ok ( SvdOwned { s, u: vt, vt: u } ) ,
193
+ }
194
+ }
195
+ }
196
+ } ;
197
+ }
198
+ impl_svd_work_c ! ( c64, lapack_sys:: zgesvd_) ;
199
+ impl_svd_work_c ! ( c32, lapack_sys:: cgesvd_) ;
200
+
201
+ macro_rules! impl_svd_work_r {
202
+ ( $s: ty, $svd: path) => {
203
+ impl SvdWorkImpl for SvdWork <$s> {
204
+ type Elem = $s;
205
+
206
+ fn new( layout: MatrixLayout , calc_u: bool , calc_vt: bool ) -> Result <Self > {
207
+ let ju = match layout {
208
+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_u) ,
209
+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_vt) ,
210
+ } ;
211
+ let jvt = match layout {
212
+ MatrixLayout :: F { .. } => JobSvd :: from_bool( calc_vt) ,
213
+ MatrixLayout :: C { .. } => JobSvd :: from_bool( calc_u) ,
214
+ } ;
215
+
216
+ let m = layout. lda( ) ;
217
+ let mut u = match ju {
218
+ JobSvd :: All => Some ( vec_uninit( ( m * m) as usize ) ) ,
219
+ JobSvd :: None => None ,
220
+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
221
+ } ;
222
+
223
+ let n = layout. len( ) ;
224
+ let mut vt = match jvt {
225
+ JobSvd :: All => Some ( vec_uninit( ( n * n) as usize ) ) ,
226
+ JobSvd :: None => None ,
227
+ _ => unimplemented!( "SVD with partial vector output is not supported yet" ) ,
228
+ } ;
229
+
230
+ let k = std:: cmp:: min( m, n) ;
231
+ let mut s = vec_uninit( k as usize ) ;
232
+
233
+ // eval work size
234
+ let mut info = 0 ;
235
+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
236
+ unsafe {
237
+ $svd(
238
+ ju. as_ptr( ) ,
239
+ jvt. as_ptr( ) ,
240
+ & m,
241
+ & n,
242
+ std:: ptr:: null_mut( ) ,
243
+ & m,
244
+ AsPtr :: as_mut_ptr( & mut s) ,
245
+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
246
+ & m,
247
+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
248
+ & n,
249
+ AsPtr :: as_mut_ptr( & mut work_size) ,
250
+ & ( -1 ) ,
251
+ & mut info,
252
+ ) ;
253
+ }
254
+ info. as_lapack_result( ) ?;
255
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
256
+ let work = vec_uninit( lwork) ;
257
+ Ok ( SvdWork {
258
+ layout,
259
+ ju,
260
+ jvt,
261
+ s,
262
+ u,
263
+ vt,
264
+ work,
265
+ rwork: None ,
266
+ } )
267
+ }
268
+
269
+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
270
+ let m = self . layout. lda( ) ;
271
+ let n = self . layout. len( ) ;
272
+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
273
+
274
+ let mut info = 0 ;
275
+ unsafe {
276
+ $svd(
277
+ self . ju. as_ptr( ) ,
278
+ self . jvt. as_ptr( ) ,
279
+ & m,
280
+ & n,
281
+ AsPtr :: as_mut_ptr( a) ,
282
+ & m,
283
+ AsPtr :: as_mut_ptr( & mut self . s) ,
284
+ AsPtr :: as_mut_ptr(
285
+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
286
+ ) ,
287
+ & m,
288
+ AsPtr :: as_mut_ptr(
289
+ self . vt
290
+ . as_mut( )
291
+ . map( |x| x. as_mut_slice( ) )
292
+ . unwrap_or( & mut [ ] ) ,
293
+ ) ,
294
+ & n,
295
+ AsPtr :: as_mut_ptr( & mut self . work) ,
296
+ & ( lwork as i32 ) ,
297
+ & mut info,
298
+ ) ;
299
+ }
300
+ info. as_lapack_result( ) ?;
301
+
302
+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
303
+ let u = self
304
+ . u
305
+ . as_ref( )
306
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
307
+ let vt = self
308
+ . vt
309
+ . as_ref( )
310
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
311
+
312
+ match self . layout {
313
+ MatrixLayout :: F { .. } => Ok ( SvdRef { s, u, vt } ) ,
314
+ MatrixLayout :: C { .. } => Ok ( SvdRef { s, u: vt, vt: u } ) ,
315
+ }
316
+ }
317
+
318
+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
319
+ let _ref = self . calc( a) ?;
320
+ let s = unsafe { self . s. assume_init( ) } ;
321
+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
322
+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
323
+ match self . layout {
324
+ MatrixLayout :: F { .. } => Ok ( SvdOwned { s, u, vt } ) ,
325
+ MatrixLayout :: C { .. } => Ok ( SvdOwned { s, u: vt, vt: u } ) ,
326
+ }
327
+ }
328
+ }
329
+ } ;
330
+ }
331
+ impl_svd_work_r ! ( f64 , lapack_sys:: dgesvd_) ;
332
+ impl_svd_work_r ! ( f32 , lapack_sys:: sgesvd_) ;
333
+
33
334
macro_rules! impl_svd {
34
335
( @real, $scalar: ty, $gesvd: path) => {
35
336
impl_svd!( @body, $scalar, $gesvd, ) ;
0 commit comments