@@ -12,6 +12,14 @@ pub struct LeastSquaresOwned<A: Scalar> {
12
12
pub rank : i32 ,
13
13
}
14
14
15
+ /// Result of LeastSquares
16
+ pub struct LeastSquaresRef < ' work , A : Scalar > {
17
+ /// singular values
18
+ pub singular_values : & ' work [ A :: Real ] ,
19
+ /// The rank of the input matrix A
20
+ pub rank : i32 ,
21
+ }
22
+
15
23
#[ cfg_attr( doc, katexit:: katexit) ]
16
24
/// Solve least square problem
17
25
pub trait LeastSquaresSvdDivideConquer_ : Scalar {
@@ -29,8 +37,325 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
29
37
a : & mut [ Self ] ,
30
38
b_layout : MatrixLayout ,
31
39
b : & mut [ Self ] ,
32
- ) -> Result < LeastSquaresOutput < Self > > ;
40
+ ) -> Result < LeastSquaresOwned < Self > > ;
41
+ }
42
+
43
+ pub struct LeastSquaresWork < T : Scalar > {
44
+ pub a_layout : MatrixLayout ,
45
+ pub b_layout : MatrixLayout ,
46
+ pub singular_values : Vec < MaybeUninit < T :: Real > > ,
47
+ pub work : Vec < MaybeUninit < T > > ,
48
+ pub iwork : Vec < MaybeUninit < i32 > > ,
49
+ pub rwork : Option < Vec < MaybeUninit < T :: Real > > > ,
50
+ }
51
+
52
+ pub trait LeastSquaresWorkImpl : Sized {
53
+ type Elem : Scalar ;
54
+ fn new ( a_layout : MatrixLayout , b_layout : MatrixLayout ) -> Result < Self > ;
55
+ fn calc (
56
+ & mut self ,
57
+ a : & mut [ Self :: Elem ] ,
58
+ b : & mut [ Self :: Elem ] ,
59
+ ) -> Result < LeastSquaresRef < Self :: Elem > > ;
60
+ fn eval (
61
+ self ,
62
+ a : & mut [ Self :: Elem ] ,
63
+ b : & mut [ Self :: Elem ] ,
64
+ ) -> Result < LeastSquaresOwned < Self :: Elem > > ;
65
+ }
66
+
67
+ macro_rules! impl_least_squares_work_c {
68
+ ( $c: ty, $lsd: path) => {
69
+ impl LeastSquaresWorkImpl for LeastSquaresWork <$c> {
70
+ type Elem = $c;
71
+
72
+ fn new( a_layout: MatrixLayout , b_layout: MatrixLayout ) -> Result <Self > {
73
+ let ( m, n) = a_layout. size( ) ;
74
+ let ( m_, nrhs) = b_layout. size( ) ;
75
+ let k = m. min( n) ;
76
+ assert!( m_ >= m) ;
77
+
78
+ let rcond = -1. ;
79
+ let mut singular_values = vec_uninit( k as usize ) ;
80
+ let mut rank: i32 = 0 ;
81
+
82
+ // eval work size
83
+ let mut info = 0 ;
84
+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
85
+ let mut iwork_size = [ 0 ] ;
86
+ let mut rwork = [ <Self :: Elem as Scalar >:: Real :: zero( ) ] ;
87
+ unsafe {
88
+ $lsd(
89
+ & m,
90
+ & n,
91
+ & nrhs,
92
+ std:: ptr:: null_mut( ) ,
93
+ & a_layout. lda( ) ,
94
+ std:: ptr:: null_mut( ) ,
95
+ & b_layout. lda( ) ,
96
+ AsPtr :: as_mut_ptr( & mut singular_values) ,
97
+ & rcond,
98
+ & mut rank,
99
+ AsPtr :: as_mut_ptr( & mut work_size) ,
100
+ & ( -1 ) ,
101
+ AsPtr :: as_mut_ptr( & mut rwork) ,
102
+ iwork_size. as_mut_ptr( ) ,
103
+ & mut info,
104
+ )
105
+ } ;
106
+ info. as_lapack_result( ) ?;
107
+
108
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
109
+ let liwork = iwork_size[ 0 ] . to_usize( ) . unwrap( ) ;
110
+ let lrwork = rwork[ 0 ] . to_usize( ) . unwrap( ) ;
111
+
112
+ let work = vec_uninit( lwork) ;
113
+ let iwork = vec_uninit( liwork) ;
114
+ let rwork = vec_uninit( lrwork) ;
115
+
116
+ Ok ( LeastSquaresWork {
117
+ a_layout,
118
+ b_layout,
119
+ work,
120
+ iwork,
121
+ rwork: Some ( rwork) ,
122
+ singular_values,
123
+ } )
124
+ }
125
+
126
+ fn calc(
127
+ & mut self ,
128
+ a: & mut [ Self :: Elem ] ,
129
+ b: & mut [ Self :: Elem ] ,
130
+ ) -> Result <LeastSquaresRef <Self :: Elem >> {
131
+ let ( m, n) = self . a_layout. size( ) ;
132
+ let ( m_, nrhs) = self . b_layout. size( ) ;
133
+ assert!( m_ >= m) ;
134
+
135
+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
136
+
137
+ // Transpose if a is C-continuous
138
+ let mut a_t = None ;
139
+ let a_layout = match self . a_layout {
140
+ MatrixLayout :: C { .. } => {
141
+ let ( layout, t) = transpose( self . a_layout, a) ;
142
+ a_t = Some ( t) ;
143
+ layout
144
+ }
145
+ MatrixLayout :: F { .. } => self . a_layout,
146
+ } ;
147
+
148
+ // Transpose if b is C-continuous
149
+ let mut b_t = None ;
150
+ let b_layout = match self . b_layout {
151
+ MatrixLayout :: C { .. } => {
152
+ let ( layout, t) = transpose( self . b_layout, b) ;
153
+ b_t = Some ( t) ;
154
+ layout
155
+ }
156
+ MatrixLayout :: F { .. } => self . b_layout,
157
+ } ;
158
+
159
+ let rcond: <Self :: Elem as Scalar >:: Real = -1. ;
160
+ let mut rank: i32 = 0 ;
161
+
162
+ let mut info = 0 ;
163
+ unsafe {
164
+ $lsd(
165
+ & m,
166
+ & n,
167
+ & nrhs,
168
+ AsPtr :: as_mut_ptr( a_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( a) ) ,
169
+ & a_layout. lda( ) ,
170
+ AsPtr :: as_mut_ptr( b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ) ,
171
+ & b_layout. lda( ) ,
172
+ AsPtr :: as_mut_ptr( & mut self . singular_values) ,
173
+ & rcond,
174
+ & mut rank,
175
+ AsPtr :: as_mut_ptr( & mut self . work) ,
176
+ & lwork,
177
+ AsPtr :: as_mut_ptr( self . rwork. as_mut( ) . unwrap( ) ) ,
178
+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
179
+ & mut info,
180
+ ) ;
181
+ }
182
+ info. as_lapack_result( ) ?;
183
+
184
+ let singular_values = unsafe { self . singular_values. slice_assume_init_ref( ) } ;
185
+
186
+ // Skip a_t -> a transpose because A has been destroyed
187
+ // Re-transpose b
188
+ if let Some ( b_t) = b_t {
189
+ transpose_over( b_layout, & b_t, b) ;
190
+ }
191
+
192
+ Ok ( LeastSquaresRef {
193
+ singular_values,
194
+ rank,
195
+ } )
196
+ }
197
+
198
+ fn eval(
199
+ mut self ,
200
+ a: & mut [ Self :: Elem ] ,
201
+ b: & mut [ Self :: Elem ] ,
202
+ ) -> Result <LeastSquaresOwned <Self :: Elem >> {
203
+ let LeastSquaresRef { rank, .. } = self . calc( a, b) ?;
204
+ let singular_values = unsafe { self . singular_values. assume_init( ) } ;
205
+ Ok ( LeastSquaresOwned {
206
+ singular_values,
207
+ rank,
208
+ } )
209
+ }
210
+ }
211
+ } ;
212
+ }
213
+ impl_least_squares_work_c ! ( c64, lapack_sys:: zgelsd_) ;
214
+ impl_least_squares_work_c ! ( c32, lapack_sys:: cgelsd_) ;
215
+
216
+ macro_rules! impl_least_squares_work_r {
217
+ ( $c: ty, $lsd: path) => {
218
+ impl LeastSquaresWorkImpl for LeastSquaresWork <$c> {
219
+ type Elem = $c;
220
+
221
+ fn new( a_layout: MatrixLayout , b_layout: MatrixLayout ) -> Result <Self > {
222
+ let ( m, n) = a_layout. size( ) ;
223
+ let ( m_, nrhs) = b_layout. size( ) ;
224
+ let k = m. min( n) ;
225
+ assert!( m_ >= m) ;
226
+
227
+ let rcond = -1. ;
228
+ let mut singular_values = vec_uninit( k as usize ) ;
229
+ let mut rank: i32 = 0 ;
230
+
231
+ // eval work size
232
+ let mut info = 0 ;
233
+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
234
+ let mut iwork_size = [ 0 ] ;
235
+ unsafe {
236
+ $lsd(
237
+ & m,
238
+ & n,
239
+ & nrhs,
240
+ std:: ptr:: null_mut( ) ,
241
+ & a_layout. lda( ) ,
242
+ std:: ptr:: null_mut( ) ,
243
+ & b_layout. lda( ) ,
244
+ AsPtr :: as_mut_ptr( & mut singular_values) ,
245
+ & rcond,
246
+ & mut rank,
247
+ AsPtr :: as_mut_ptr( & mut work_size) ,
248
+ & ( -1 ) ,
249
+ iwork_size. as_mut_ptr( ) ,
250
+ & mut info,
251
+ )
252
+ } ;
253
+ info. as_lapack_result( ) ?;
254
+
255
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
256
+ let liwork = iwork_size[ 0 ] . to_usize( ) . unwrap( ) ;
257
+
258
+ let work = vec_uninit( lwork) ;
259
+ let iwork = vec_uninit( liwork) ;
260
+
261
+ Ok ( LeastSquaresWork {
262
+ a_layout,
263
+ b_layout,
264
+ work,
265
+ iwork,
266
+ rwork: None ,
267
+ singular_values,
268
+ } )
269
+ }
270
+
271
+ fn calc(
272
+ & mut self ,
273
+ a: & mut [ Self :: Elem ] ,
274
+ b: & mut [ Self :: Elem ] ,
275
+ ) -> Result <LeastSquaresRef <Self :: Elem >> {
276
+ let ( m, n) = self . a_layout. size( ) ;
277
+ let ( m_, nrhs) = self . b_layout. size( ) ;
278
+ assert!( m_ >= m) ;
279
+
280
+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
281
+
282
+ // Transpose if a is C-continuous
283
+ let mut a_t = None ;
284
+ let a_layout = match self . a_layout {
285
+ MatrixLayout :: C { .. } => {
286
+ let ( layout, t) = transpose( self . a_layout, a) ;
287
+ a_t = Some ( t) ;
288
+ layout
289
+ }
290
+ MatrixLayout :: F { .. } => self . a_layout,
291
+ } ;
292
+
293
+ // Transpose if b is C-continuous
294
+ let mut b_t = None ;
295
+ let b_layout = match self . b_layout {
296
+ MatrixLayout :: C { .. } => {
297
+ let ( layout, t) = transpose( self . b_layout, b) ;
298
+ b_t = Some ( t) ;
299
+ layout
300
+ }
301
+ MatrixLayout :: F { .. } => self . b_layout,
302
+ } ;
303
+
304
+ let rcond: <Self :: Elem as Scalar >:: Real = -1. ;
305
+ let mut rank: i32 = 0 ;
306
+
307
+ let mut info = 0 ;
308
+ unsafe {
309
+ $lsd(
310
+ & m,
311
+ & n,
312
+ & nrhs,
313
+ AsPtr :: as_mut_ptr( a_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( a) ) ,
314
+ & a_layout. lda( ) ,
315
+ AsPtr :: as_mut_ptr( b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ) ,
316
+ & b_layout. lda( ) ,
317
+ AsPtr :: as_mut_ptr( & mut self . singular_values) ,
318
+ & rcond,
319
+ & mut rank,
320
+ AsPtr :: as_mut_ptr( & mut self . work) ,
321
+ & lwork,
322
+ AsPtr :: as_mut_ptr( & mut self . iwork) ,
323
+ & mut info,
324
+ ) ;
325
+ }
326
+ info. as_lapack_result( ) ?;
327
+
328
+ let singular_values = unsafe { self . singular_values. slice_assume_init_ref( ) } ;
329
+
330
+ // Skip a_t -> a transpose because A has been destroyed
331
+ // Re-transpose b
332
+ if let Some ( b_t) = b_t {
333
+ transpose_over( b_layout, & b_t, b) ;
334
+ }
335
+
336
+ Ok ( LeastSquaresRef {
337
+ singular_values,
338
+ rank,
339
+ } )
340
+ }
341
+
342
+ fn eval(
343
+ mut self ,
344
+ a: & mut [ Self :: Elem ] ,
345
+ b: & mut [ Self :: Elem ] ,
346
+ ) -> Result <LeastSquaresOwned <Self :: Elem >> {
347
+ let LeastSquaresRef { rank, .. } = self . calc( a, b) ?;
348
+ let singular_values = unsafe { self . singular_values. assume_init( ) } ;
349
+ Ok ( LeastSquaresOwned {
350
+ singular_values,
351
+ rank,
352
+ } )
353
+ }
354
+ }
355
+ } ;
33
356
}
357
+ impl_least_squares_work_r ! ( f64 , lapack_sys:: dgelsd_) ;
358
+ impl_least_squares_work_r ! ( f32 , lapack_sys:: sgelsd_) ;
34
359
35
360
macro_rules! impl_least_squares {
36
361
( @real, $scalar: ty, $gelsd: path) => {
0 commit comments