Skip to content

Commit da3221a

Browse files
committed
LeastSquaresWork
1 parent ac2f7bc commit da3221a

File tree

1 file changed

+326
-1
lines changed

1 file changed

+326
-1
lines changed

lax/src/least_squares.rs

Lines changed: 326 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ pub struct LeastSquaresOwned<A: Scalar> {
1212
pub rank: i32,
1313
}
1414

15+
/// Result of LeastSquares
16+
pub struct LeastSquaresRef<'work, A: Scalar> {
17+
/// singular values
18+
pub singular_values: &'work [A::Real],
19+
/// The rank of the input matrix A
20+
pub rank: i32,
21+
}
22+
1523
#[cfg_attr(doc, katexit::katexit)]
1624
/// Solve least square problem
1725
pub trait LeastSquaresSvdDivideConquer_: Scalar {
@@ -29,8 +37,325 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
2937
a: &mut [Self],
3038
b_layout: MatrixLayout,
3139
b: &mut [Self],
32-
) -> Result<LeastSquaresOutput<Self>>;
40+
) -> Result<LeastSquaresOwned<Self>>;
41+
}
42+
43+
pub struct LeastSquaresWork<T: Scalar> {
44+
pub a_layout: MatrixLayout,
45+
pub b_layout: MatrixLayout,
46+
pub singular_values: Vec<MaybeUninit<T::Real>>,
47+
pub work: Vec<MaybeUninit<T>>,
48+
pub iwork: Vec<MaybeUninit<i32>>,
49+
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
50+
}
51+
52+
pub trait LeastSquaresWorkImpl: Sized {
53+
type Elem: Scalar;
54+
fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self>;
55+
fn calc(
56+
&mut self,
57+
a: &mut [Self::Elem],
58+
b: &mut [Self::Elem],
59+
) -> Result<LeastSquaresRef<Self::Elem>>;
60+
fn eval(
61+
self,
62+
a: &mut [Self::Elem],
63+
b: &mut [Self::Elem],
64+
) -> Result<LeastSquaresOwned<Self::Elem>>;
65+
}
66+
67+
macro_rules! impl_least_squares_work_c {
68+
($c:ty, $lsd:path) => {
69+
impl LeastSquaresWorkImpl for LeastSquaresWork<$c> {
70+
type Elem = $c;
71+
72+
fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self> {
73+
let (m, n) = a_layout.size();
74+
let (m_, nrhs) = b_layout.size();
75+
let k = m.min(n);
76+
assert!(m_ >= m);
77+
78+
let rcond = -1.;
79+
let mut singular_values = vec_uninit(k as usize);
80+
let mut rank: i32 = 0;
81+
82+
// eval work size
83+
let mut info = 0;
84+
let mut work_size = [Self::Elem::zero()];
85+
let mut iwork_size = [0];
86+
let mut rwork = [<Self::Elem as Scalar>::Real::zero()];
87+
unsafe {
88+
$lsd(
89+
&m,
90+
&n,
91+
&nrhs,
92+
std::ptr::null_mut(),
93+
&a_layout.lda(),
94+
std::ptr::null_mut(),
95+
&b_layout.lda(),
96+
AsPtr::as_mut_ptr(&mut singular_values),
97+
&rcond,
98+
&mut rank,
99+
AsPtr::as_mut_ptr(&mut work_size),
100+
&(-1),
101+
AsPtr::as_mut_ptr(&mut rwork),
102+
iwork_size.as_mut_ptr(),
103+
&mut info,
104+
)
105+
};
106+
info.as_lapack_result()?;
107+
108+
let lwork = work_size[0].to_usize().unwrap();
109+
let liwork = iwork_size[0].to_usize().unwrap();
110+
let lrwork = rwork[0].to_usize().unwrap();
111+
112+
let work = vec_uninit(lwork);
113+
let iwork = vec_uninit(liwork);
114+
let rwork = vec_uninit(lrwork);
115+
116+
Ok(LeastSquaresWork {
117+
a_layout,
118+
b_layout,
119+
work,
120+
iwork,
121+
rwork: Some(rwork),
122+
singular_values,
123+
})
124+
}
125+
126+
fn calc(
127+
&mut self,
128+
a: &mut [Self::Elem],
129+
b: &mut [Self::Elem],
130+
) -> Result<LeastSquaresRef<Self::Elem>> {
131+
let (m, n) = self.a_layout.size();
132+
let (m_, nrhs) = self.b_layout.size();
133+
assert!(m_ >= m);
134+
135+
let lwork = self.work.len().to_i32().unwrap();
136+
137+
// Transpose if a is C-continuous
138+
let mut a_t = None;
139+
let a_layout = match self.a_layout {
140+
MatrixLayout::C { .. } => {
141+
let (layout, t) = transpose(self.a_layout, a);
142+
a_t = Some(t);
143+
layout
144+
}
145+
MatrixLayout::F { .. } => self.a_layout,
146+
};
147+
148+
// Transpose if b is C-continuous
149+
let mut b_t = None;
150+
let b_layout = match self.b_layout {
151+
MatrixLayout::C { .. } => {
152+
let (layout, t) = transpose(self.b_layout, b);
153+
b_t = Some(t);
154+
layout
155+
}
156+
MatrixLayout::F { .. } => self.b_layout,
157+
};
158+
159+
let rcond: <Self::Elem as Scalar>::Real = -1.;
160+
let mut rank: i32 = 0;
161+
162+
let mut info = 0;
163+
unsafe {
164+
$lsd(
165+
&m,
166+
&n,
167+
&nrhs,
168+
AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)),
169+
&a_layout.lda(),
170+
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
171+
&b_layout.lda(),
172+
AsPtr::as_mut_ptr(&mut self.singular_values),
173+
&rcond,
174+
&mut rank,
175+
AsPtr::as_mut_ptr(&mut self.work),
176+
&lwork,
177+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
178+
AsPtr::as_mut_ptr(&mut self.iwork),
179+
&mut info,
180+
);
181+
}
182+
info.as_lapack_result()?;
183+
184+
let singular_values = unsafe { self.singular_values.slice_assume_init_ref() };
185+
186+
// Skip a_t -> a transpose because A has been destroyed
187+
// Re-transpose b
188+
if let Some(b_t) = b_t {
189+
transpose_over(b_layout, &b_t, b);
190+
}
191+
192+
Ok(LeastSquaresRef {
193+
singular_values,
194+
rank,
195+
})
196+
}
197+
198+
fn eval(
199+
mut self,
200+
a: &mut [Self::Elem],
201+
b: &mut [Self::Elem],
202+
) -> Result<LeastSquaresOwned<Self::Elem>> {
203+
let LeastSquaresRef { rank, .. } = self.calc(a, b)?;
204+
let singular_values = unsafe { self.singular_values.assume_init() };
205+
Ok(LeastSquaresOwned {
206+
singular_values,
207+
rank,
208+
})
209+
}
210+
}
211+
};
212+
}
213+
impl_least_squares_work_c!(c64, lapack_sys::zgelsd_);
214+
impl_least_squares_work_c!(c32, lapack_sys::cgelsd_);
215+
216+
macro_rules! impl_least_squares_work_r {
217+
($c:ty, $lsd:path) => {
218+
impl LeastSquaresWorkImpl for LeastSquaresWork<$c> {
219+
type Elem = $c;
220+
221+
fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self> {
222+
let (m, n) = a_layout.size();
223+
let (m_, nrhs) = b_layout.size();
224+
let k = m.min(n);
225+
assert!(m_ >= m);
226+
227+
let rcond = -1.;
228+
let mut singular_values = vec_uninit(k as usize);
229+
let mut rank: i32 = 0;
230+
231+
// eval work size
232+
let mut info = 0;
233+
let mut work_size = [Self::Elem::zero()];
234+
let mut iwork_size = [0];
235+
unsafe {
236+
$lsd(
237+
&m,
238+
&n,
239+
&nrhs,
240+
std::ptr::null_mut(),
241+
&a_layout.lda(),
242+
std::ptr::null_mut(),
243+
&b_layout.lda(),
244+
AsPtr::as_mut_ptr(&mut singular_values),
245+
&rcond,
246+
&mut rank,
247+
AsPtr::as_mut_ptr(&mut work_size),
248+
&(-1),
249+
iwork_size.as_mut_ptr(),
250+
&mut info,
251+
)
252+
};
253+
info.as_lapack_result()?;
254+
255+
let lwork = work_size[0].to_usize().unwrap();
256+
let liwork = iwork_size[0].to_usize().unwrap();
257+
258+
let work = vec_uninit(lwork);
259+
let iwork = vec_uninit(liwork);
260+
261+
Ok(LeastSquaresWork {
262+
a_layout,
263+
b_layout,
264+
work,
265+
iwork,
266+
rwork: None,
267+
singular_values,
268+
})
269+
}
270+
271+
fn calc(
272+
&mut self,
273+
a: &mut [Self::Elem],
274+
b: &mut [Self::Elem],
275+
) -> Result<LeastSquaresRef<Self::Elem>> {
276+
let (m, n) = self.a_layout.size();
277+
let (m_, nrhs) = self.b_layout.size();
278+
assert!(m_ >= m);
279+
280+
let lwork = self.work.len().to_i32().unwrap();
281+
282+
// Transpose if a is C-continuous
283+
let mut a_t = None;
284+
let a_layout = match self.a_layout {
285+
MatrixLayout::C { .. } => {
286+
let (layout, t) = transpose(self.a_layout, a);
287+
a_t = Some(t);
288+
layout
289+
}
290+
MatrixLayout::F { .. } => self.a_layout,
291+
};
292+
293+
// Transpose if b is C-continuous
294+
let mut b_t = None;
295+
let b_layout = match self.b_layout {
296+
MatrixLayout::C { .. } => {
297+
let (layout, t) = transpose(self.b_layout, b);
298+
b_t = Some(t);
299+
layout
300+
}
301+
MatrixLayout::F { .. } => self.b_layout,
302+
};
303+
304+
let rcond: <Self::Elem as Scalar>::Real = -1.;
305+
let mut rank: i32 = 0;
306+
307+
let mut info = 0;
308+
unsafe {
309+
$lsd(
310+
&m,
311+
&n,
312+
&nrhs,
313+
AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)),
314+
&a_layout.lda(),
315+
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
316+
&b_layout.lda(),
317+
AsPtr::as_mut_ptr(&mut self.singular_values),
318+
&rcond,
319+
&mut rank,
320+
AsPtr::as_mut_ptr(&mut self.work),
321+
&lwork,
322+
AsPtr::as_mut_ptr(&mut self.iwork),
323+
&mut info,
324+
);
325+
}
326+
info.as_lapack_result()?;
327+
328+
let singular_values = unsafe { self.singular_values.slice_assume_init_ref() };
329+
330+
// Skip a_t -> a transpose because A has been destroyed
331+
// Re-transpose b
332+
if let Some(b_t) = b_t {
333+
transpose_over(b_layout, &b_t, b);
334+
}
335+
336+
Ok(LeastSquaresRef {
337+
singular_values,
338+
rank,
339+
})
340+
}
341+
342+
fn eval(
343+
mut self,
344+
a: &mut [Self::Elem],
345+
b: &mut [Self::Elem],
346+
) -> Result<LeastSquaresOwned<Self::Elem>> {
347+
let LeastSquaresRef { rank, .. } = self.calc(a, b)?;
348+
let singular_values = unsafe { self.singular_values.assume_init() };
349+
Ok(LeastSquaresOwned {
350+
singular_values,
351+
rank,
352+
})
353+
}
354+
}
355+
};
33356
}
357+
impl_least_squares_work_r!(f64, lapack_sys::dgelsd_);
358+
impl_least_squares_work_r!(f32, lapack_sys::sgelsd_);
34359

35360
macro_rules! impl_least_squares {
36361
(@real, $scalar:ty, $gelsd:path) => {

0 commit comments

Comments
 (0)