@@ -14,9 +14,303 @@ pub trait SVDDC_: Scalar {
14
14
/// |:-------|:-------|:-------|:-------|
15
15
/// | sgesdd | dgesdd | cgesdd | zgesdd |
16
16
///
17
- fn svddc ( l : MatrixLayout , jobz : JobSvd , a : & mut [ Self ] ) -> Result < SvdOwned < Self > > ;
17
+ fn svddc ( layout : MatrixLayout , jobz : JobSvd , a : & mut [ Self ] ) -> Result < SvdOwned < Self > > ;
18
18
}
19
19
20
+ pub struct SvdDcWork < T : Scalar > {
21
+ pub jobz : JobSvd ,
22
+ pub layout : MatrixLayout ,
23
+ pub s : Vec < MaybeUninit < T :: Real > > ,
24
+ pub u : Option < Vec < MaybeUninit < T > > > ,
25
+ pub vt : Option < Vec < MaybeUninit < T > > > ,
26
+ pub work : Vec < MaybeUninit < T > > ,
27
+ pub iwork : Vec < MaybeUninit < i32 > > ,
28
+ pub rwork : Option < Vec < MaybeUninit < T :: Real > > > ,
29
+ }
30
+
31
+ pub trait SvdDcWorkImpl : Sized {
32
+ type Elem : Scalar ;
33
+ fn new ( layout : MatrixLayout , jobz : JobSvd ) -> Result < Self > ;
34
+ fn calc ( & mut self , a : & mut [ Self :: Elem ] ) -> Result < SvdRef < Self :: Elem > > ;
35
+ fn eval ( self , a : & mut [ Self :: Elem ] ) -> Result < SvdOwned < Self :: Elem > > ;
36
+ }
37
+
38
+ macro_rules! impl_svd_dc_work_c {
39
+ ( $s: ty, $sdd: path) => {
40
+ impl SvdDcWorkImpl for SvdDcWork <$s> {
41
+ type Elem = $s;
42
+
43
+ fn new( layout: MatrixLayout , jobz: JobSvd ) -> Result <Self > {
44
+ let m = layout. lda( ) ;
45
+ let n = layout. len( ) ;
46
+ let k = m. min( n) ;
47
+ let ( u_col, vt_row) = match jobz {
48
+ JobSvd :: All | JobSvd :: None => ( m, n) ,
49
+ JobSvd :: Some => ( k, k) ,
50
+ } ;
51
+
52
+ let mut s = vec_uninit( k as usize ) ;
53
+ let ( mut u, mut vt) = match jobz {
54
+ JobSvd :: All => (
55
+ Some ( vec_uninit( ( m * m) as usize ) ) ,
56
+ Some ( vec_uninit( ( n * n) as usize ) ) ,
57
+ ) ,
58
+ JobSvd :: Some => (
59
+ Some ( vec_uninit( ( m * u_col) as usize ) ) ,
60
+ Some ( vec_uninit( ( n * vt_row) as usize ) ) ,
61
+ ) ,
62
+ JobSvd :: None => ( None , None ) ,
63
+ } ;
64
+ let mut iwork = vec_uninit( 8 * k as usize ) ;
65
+
66
+ let mx = n. max( m) as usize ;
67
+ let mn = n. min( m) as usize ;
68
+ let lrwork = match jobz {
69
+ JobSvd :: None => 7 * mn,
70
+ _ => std:: cmp:: max( 5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn) ,
71
+ } ;
72
+ let mut rwork = vec_uninit( lrwork) ;
73
+
74
+ let mut info = 0 ;
75
+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
76
+ unsafe {
77
+ $sdd(
78
+ jobz. as_ptr( ) ,
79
+ & m,
80
+ & n,
81
+ std:: ptr:: null_mut( ) ,
82
+ & m,
83
+ AsPtr :: as_mut_ptr( & mut s) ,
84
+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
85
+ & m,
86
+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
87
+ & vt_row,
88
+ AsPtr :: as_mut_ptr( & mut work_size) ,
89
+ & ( -1 ) ,
90
+ AsPtr :: as_mut_ptr( & mut rwork) ,
91
+ AsPtr :: as_mut_ptr( & mut iwork) ,
92
+ & mut info,
93
+ ) ;
94
+ }
95
+ info. as_lapack_result( ) ?;
96
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
97
+ let work = vec_uninit( lwork) ;
98
+ Ok ( SvdDcWork {
99
+ layout,
100
+ jobz,
101
+ iwork,
102
+ work,
103
+ rwork: Some ( rwork) ,
104
+ u,
105
+ vt,
106
+ s,
107
+ } )
108
+ }
109
+
110
+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
111
+ let m = self . layout. lda( ) ;
112
+ let n = self . layout. len( ) ;
113
+ let k = m. min( n) ;
114
+ let ( _, vt_row) = match self . jobz {
115
+ JobSvd :: All | JobSvd :: None => ( m, n) ,
116
+ JobSvd :: Some => ( k, k) ,
117
+ } ;
118
+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
119
+
120
+ let mut info = 0 ;
121
+ unsafe {
122
+ $sdd(
123
+ self . jobz. as_ptr( ) ,
124
+ & m,
125
+ & n,
126
+ AsPtr :: as_mut_ptr( a) ,
127
+ & m,
128
+ AsPtr :: as_mut_ptr( & mut self . s) ,
129
+ AsPtr :: as_mut_ptr(
130
+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
131
+ ) ,
132
+ & m,
133
+ AsPtr :: as_mut_ptr(
134
+ self . vt
135
+ . as_mut( )
136
+ . map( |x| x. as_mut_slice( ) )
137
+ . unwrap_or( & mut [ ] ) ,
138
+ ) ,
139
+ & vt_row,
140
+ AsPtr :: as_mut_ptr( & mut self . work) ,
141
+ & lwork,
142
+ AsPtr :: as_mut_ptr( self . rwork. as_mut( ) . unwrap( ) ) ,
143
+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
144
+ & mut info,
145
+ ) ;
146
+ }
147
+ info. as_lapack_result( ) ?;
148
+
149
+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
150
+ let u = self
151
+ . u
152
+ . as_ref( )
153
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
154
+ let vt = self
155
+ . vt
156
+ . as_ref( )
157
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
158
+
159
+ Ok ( match self . layout {
160
+ MatrixLayout :: F { .. } => SvdRef { s, u, vt } ,
161
+ MatrixLayout :: C { .. } => SvdRef { s, u: vt, vt: u } ,
162
+ } )
163
+ }
164
+
165
+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
166
+ let _ref = self . calc( a) ?;
167
+ let s = unsafe { self . s. assume_init( ) } ;
168
+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
169
+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
170
+ Ok ( match self . layout {
171
+ MatrixLayout :: F { .. } => SvdOwned { s, u, vt } ,
172
+ MatrixLayout :: C { .. } => SvdOwned { s, u: vt, vt: u } ,
173
+ } )
174
+ }
175
+ }
176
+ } ;
177
+ }
178
+ impl_svd_dc_work_c ! ( c64, lapack_sys:: zgesdd_) ;
179
+ impl_svd_dc_work_c ! ( c32, lapack_sys:: cgesdd_) ;
180
+
181
+ macro_rules! impl_svd_dc_work_r {
182
+ ( $s: ty, $sdd: path) => {
183
+ impl SvdDcWorkImpl for SvdDcWork <$s> {
184
+ type Elem = $s;
185
+
186
+ fn new( layout: MatrixLayout , jobz: JobSvd ) -> Result <Self > {
187
+ let m = layout. lda( ) ;
188
+ let n = layout. len( ) ;
189
+ let k = m. min( n) ;
190
+ let ( u_col, vt_row) = match jobz {
191
+ JobSvd :: All | JobSvd :: None => ( m, n) ,
192
+ JobSvd :: Some => ( k, k) ,
193
+ } ;
194
+
195
+ let mut s = vec_uninit( k as usize ) ;
196
+ let ( mut u, mut vt) = match jobz {
197
+ JobSvd :: All => (
198
+ Some ( vec_uninit( ( m * m) as usize ) ) ,
199
+ Some ( vec_uninit( ( n * n) as usize ) ) ,
200
+ ) ,
201
+ JobSvd :: Some => (
202
+ Some ( vec_uninit( ( m * u_col) as usize ) ) ,
203
+ Some ( vec_uninit( ( n * vt_row) as usize ) ) ,
204
+ ) ,
205
+ JobSvd :: None => ( None , None ) ,
206
+ } ;
207
+ let mut iwork = vec_uninit( 8 * k as usize ) ;
208
+
209
+ let mut info = 0 ;
210
+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
211
+ unsafe {
212
+ $sdd(
213
+ jobz. as_ptr( ) ,
214
+ & m,
215
+ & n,
216
+ std:: ptr:: null_mut( ) ,
217
+ & m,
218
+ AsPtr :: as_mut_ptr( & mut s) ,
219
+ AsPtr :: as_mut_ptr( u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
220
+ & m,
221
+ AsPtr :: as_mut_ptr( vt. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ) ,
222
+ & vt_row,
223
+ AsPtr :: as_mut_ptr( & mut work_size) ,
224
+ & ( -1 ) ,
225
+ AsPtr :: as_mut_ptr( & mut iwork) ,
226
+ & mut info,
227
+ ) ;
228
+ }
229
+ info. as_lapack_result( ) ?;
230
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
231
+ let work = vec_uninit( lwork) ;
232
+ Ok ( SvdDcWork {
233
+ layout,
234
+ jobz,
235
+ iwork,
236
+ work,
237
+ rwork: None ,
238
+ u,
239
+ vt,
240
+ s,
241
+ } )
242
+ }
243
+
244
+ fn calc( & mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdRef <Self :: Elem >> {
245
+ let m = self . layout. lda( ) ;
246
+ let n = self . layout. len( ) ;
247
+ let k = m. min( n) ;
248
+ let ( _, vt_row) = match self . jobz {
249
+ JobSvd :: All | JobSvd :: None => ( m, n) ,
250
+ JobSvd :: Some => ( k, k) ,
251
+ } ;
252
+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
253
+
254
+ let mut info = 0 ;
255
+ unsafe {
256
+ $sdd(
257
+ self . jobz. as_ptr( ) ,
258
+ & m,
259
+ & n,
260
+ AsPtr :: as_mut_ptr( a) ,
261
+ & m,
262
+ AsPtr :: as_mut_ptr( & mut self . s) ,
263
+ AsPtr :: as_mut_ptr(
264
+ self . u. as_mut( ) . map( |x| x. as_mut_slice( ) ) . unwrap_or( & mut [ ] ) ,
265
+ ) ,
266
+ & m,
267
+ AsPtr :: as_mut_ptr(
268
+ self . vt
269
+ . as_mut( )
270
+ . map( |x| x. as_mut_slice( ) )
271
+ . unwrap_or( & mut [ ] ) ,
272
+ ) ,
273
+ & vt_row,
274
+ AsPtr :: as_mut_ptr( & mut self . work) ,
275
+ & lwork,
276
+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
277
+ & mut info,
278
+ ) ;
279
+ }
280
+ info. as_lapack_result( ) ?;
281
+
282
+ let s = unsafe { self . s. slice_assume_init_ref( ) } ;
283
+ let u = self
284
+ . u
285
+ . as_ref( )
286
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
287
+ let vt = self
288
+ . vt
289
+ . as_ref( )
290
+ . map( |v| unsafe { v. slice_assume_init_ref( ) } ) ;
291
+
292
+ Ok ( match self . layout {
293
+ MatrixLayout :: F { .. } => SvdRef { s, u, vt } ,
294
+ MatrixLayout :: C { .. } => SvdRef { s, u: vt, vt: u } ,
295
+ } )
296
+ }
297
+
298
+ fn eval( mut self , a: & mut [ Self :: Elem ] ) -> Result <SvdOwned <Self :: Elem >> {
299
+ let _ref = self . calc( a) ?;
300
+ let s = unsafe { self . s. assume_init( ) } ;
301
+ let u = self . u. map( |v| unsafe { v. assume_init( ) } ) ;
302
+ let vt = self . vt. map( |v| unsafe { v. assume_init( ) } ) ;
303
+ Ok ( match self . layout {
304
+ MatrixLayout :: F { .. } => SvdOwned { s, u, vt } ,
305
+ MatrixLayout :: C { .. } => SvdOwned { s, u: vt, vt: u } ,
306
+ } )
307
+ }
308
+ }
309
+ } ;
310
+ }
311
+ impl_svd_dc_work_r ! ( f64 , lapack_sys:: dgesdd_) ;
312
+ impl_svd_dc_work_r ! ( f32 , lapack_sys:: sgesdd_) ;
313
+
20
314
macro_rules! impl_svddc {
21
315
( @real, $scalar: ty, $gesdd: path) => {
22
316
impl_svddc!( @body, $scalar, $gesdd, ) ;
0 commit comments