Skip to content

Commit b379b92

Browse files
committed
Merge LeastSquaresSvdDivideConquer_ into Lapack trait
1 parent da3221a commit b379b92

File tree

3 files changed

+39
-171
lines changed

3 files changed

+39
-171
lines changed

lax/src/least_squares.rs

Lines changed: 0 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,6 @@ pub struct LeastSquaresRef<'work, A: Scalar> {
2020
pub rank: i32,
2121
}
2222

23-
#[cfg_attr(doc, katexit::katexit)]
24-
/// Solve least square problem
25-
pub trait LeastSquaresSvdDivideConquer_: Scalar {
26-
/// Compute a vector $x$ which minimizes Euclidian norm $\| Ax - b\|$
27-
/// for a given matrix $A$ and a vector $b$.
28-
fn least_squares(
29-
a_layout: MatrixLayout,
30-
a: &mut [Self],
31-
b: &mut [Self],
32-
) -> Result<LeastSquaresOwned<Self>>;
33-
34-
/// Solve least square problems $\argmin_X \| AX - B\|$
35-
fn least_squares_nrhs(
36-
a_layout: MatrixLayout,
37-
a: &mut [Self],
38-
b_layout: MatrixLayout,
39-
b: &mut [Self],
40-
) -> Result<LeastSquaresOwned<Self>>;
41-
}
42-
4323
pub struct LeastSquaresWork<T: Scalar> {
4424
pub a_layout: MatrixLayout,
4525
pub b_layout: MatrixLayout,
@@ -356,145 +336,3 @@ macro_rules! impl_least_squares_work_r {
356336
}
357337
impl_least_squares_work_r!(f64, lapack_sys::dgelsd_);
358338
impl_least_squares_work_r!(f32, lapack_sys::sgelsd_);
359-
360-
macro_rules! impl_least_squares {
361-
(@real, $scalar:ty, $gelsd:path) => {
362-
impl_least_squares!(@body, $scalar, $gelsd, );
363-
};
364-
(@complex, $scalar:ty, $gelsd:path) => {
365-
impl_least_squares!(@body, $scalar, $gelsd, rwork);
366-
};
367-
368-
(@body, $scalar:ty, $gelsd:path, $($rwork:ident),*) => {
369-
impl LeastSquaresSvdDivideConquer_ for $scalar {
370-
fn least_squares(
371-
l: MatrixLayout,
372-
a: &mut [Self],
373-
b: &mut [Self],
374-
) -> Result<LeastSquaresOwned<Self>> {
375-
let b_layout = l.resized(b.len() as i32, 1);
376-
Self::least_squares_nrhs(l, a, b_layout, b)
377-
}
378-
379-
fn least_squares_nrhs(
380-
a_layout: MatrixLayout,
381-
a: &mut [Self],
382-
b_layout: MatrixLayout,
383-
b: &mut [Self],
384-
) -> Result<LeastSquaresOwned<Self>> {
385-
// Minimize |b - Ax|_2
386-
//
387-
// where
388-
// A : (m, n)
389-
// b : (max(m, n), nrhs) // `b` has to store `x` on exit
390-
// x : (n, nrhs)
391-
let (m, n) = a_layout.size();
392-
let (m_, nrhs) = b_layout.size();
393-
let k = m.min(n);
394-
assert!(m_ >= m);
395-
396-
// Transpose if a is C-continuous
397-
let mut a_t = None;
398-
let a_layout = match a_layout {
399-
MatrixLayout::C { .. } => {
400-
let (layout, t) = transpose(a_layout, a);
401-
a_t = Some(t);
402-
layout
403-
}
404-
MatrixLayout::F { .. } => a_layout,
405-
};
406-
407-
// Transpose if b is C-continuous
408-
let mut b_t = None;
409-
let b_layout = match b_layout {
410-
MatrixLayout::C { .. } => {
411-
let (layout, t) = transpose(b_layout, b);
412-
b_t = Some(t);
413-
layout
414-
}
415-
MatrixLayout::F { .. } => b_layout,
416-
};
417-
418-
let rcond: Self::Real = -1.;
419-
let mut singular_values: Vec<MaybeUninit<Self::Real>> = vec_uninit( k as usize);
420-
let mut rank: i32 = 0;
421-
422-
// eval work size
423-
let mut info = 0;
424-
let mut work_size = [Self::zero()];
425-
let mut iwork_size = [0];
426-
$(
427-
let mut $rwork = [Self::Real::zero()];
428-
)*
429-
unsafe {
430-
$gelsd(
431-
&m,
432-
&n,
433-
&nrhs,
434-
AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)),
435-
&a_layout.lda(),
436-
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
437-
&b_layout.lda(),
438-
AsPtr::as_mut_ptr(&mut singular_values),
439-
&rcond,
440-
&mut rank,
441-
AsPtr::as_mut_ptr(&mut work_size),
442-
&(-1),
443-
$(AsPtr::as_mut_ptr(&mut $rwork),)*
444-
iwork_size.as_mut_ptr(),
445-
&mut info,
446-
)
447-
};
448-
info.as_lapack_result()?;
449-
450-
// calc
451-
let lwork = work_size[0].to_usize().unwrap();
452-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
453-
let liwork = iwork_size[0].to_usize().unwrap();
454-
let mut iwork: Vec<MaybeUninit<i32>> = vec_uninit(liwork);
455-
$(
456-
let lrwork = $rwork[0].to_usize().unwrap();
457-
let mut $rwork: Vec<MaybeUninit<Self::Real>> = vec_uninit(lrwork);
458-
)*
459-
unsafe {
460-
$gelsd(
461-
&m,
462-
&n,
463-
&nrhs,
464-
AsPtr::as_mut_ptr(a_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(a)),
465-
&a_layout.lda(),
466-
AsPtr::as_mut_ptr(b_t.as_mut().map(|v| v.as_mut_slice()).unwrap_or(b)),
467-
&b_layout.lda(),
468-
AsPtr::as_mut_ptr(&mut singular_values),
469-
&rcond,
470-
&mut rank,
471-
AsPtr::as_mut_ptr(&mut work),
472-
&(lwork as i32),
473-
$(AsPtr::as_mut_ptr(&mut $rwork),)*
474-
AsPtr::as_mut_ptr(&mut iwork),
475-
&mut info,
476-
);
477-
}
478-
info.as_lapack_result()?;
479-
480-
let singular_values = unsafe { singular_values.assume_init() };
481-
482-
// Skip a_t -> a transpose because A has been destroyed
483-
// Re-transpose b
484-
if let Some(b_t) = b_t {
485-
transpose_over(b_layout, &b_t, b);
486-
}
487-
488-
Ok(LeastSquaresOwned {
489-
singular_values,
490-
rank,
491-
})
492-
}
493-
}
494-
};
495-
}
496-
497-
impl_least_squares!(@real, f64, lapack_sys::dgelsd_);
498-
impl_least_squares!(@real, f32, lapack_sys::sgelsd_);
499-
impl_least_squares!(@complex, c64, lapack_sys::zgelsd_);
500-
impl_least_squares!(@complex, c32, lapack_sys::cgelsd_);

lax/src/lib.rs

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,10 @@ use std::mem::MaybeUninit;
120120

121121
pub type Pivot = Vec<i32>;
122122

123+
#[cfg_attr(doc, katexit::katexit)]
123124
/// Trait for primitive types which implements LAPACK subroutines
124125
pub trait Lapack:
125-
OperatorNorm_
126-
+ Solve_
127-
+ Solveh_
128-
+ Cholesky_
129-
+ Triangular_
130-
+ Tridiagonal_
131-
+ Rcond_
132-
+ LeastSquaresSvdDivideConquer_
126+
OperatorNorm_ + Solve_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_
133127
{
134128
/// Compute right eigenvalue and eigenvectors for a general matrix
135129
fn eig(
@@ -172,6 +166,22 @@ pub trait Lapack:
172166

173167
/// Compute singular value decomposition (SVD) with divide-and-conquer algorithm
174168
fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>>;
169+
170+
/// Compute a vector $x$ which minimizes Euclidian norm $\| Ax - b\|$
171+
/// for a given matrix $A$ and a vector $b$.
172+
fn least_squares(
173+
a_layout: MatrixLayout,
174+
a: &mut [Self],
175+
b: &mut [Self],
176+
) -> Result<LeastSquaresOwned<Self>>;
177+
178+
/// Solve least square problems $\argmin_X \| AX - B\|$
179+
fn least_squares_nrhs(
180+
a_layout: MatrixLayout,
181+
a: &mut [Self],
182+
b_layout: MatrixLayout,
183+
b: &mut [Self],
184+
) -> Result<LeastSquaresOwned<Self>>;
175185
}
176186

177187
macro_rules! impl_lapack {
@@ -247,6 +257,26 @@ macro_rules! impl_lapack {
247257
let work = SvdDcWork::<$s>::new(layout, jobz)?;
248258
work.eval(a)
249259
}
260+
261+
fn least_squares(
262+
l: MatrixLayout,
263+
a: &mut [Self],
264+
b: &mut [Self],
265+
) -> Result<LeastSquaresOwned<Self>> {
266+
let b_layout = l.resized(b.len() as i32, 1);
267+
Self::least_squares_nrhs(l, a, b_layout, b)
268+
}
269+
270+
fn least_squares_nrhs(
271+
a_layout: MatrixLayout,
272+
a: &mut [Self],
273+
b_layout: MatrixLayout,
274+
b: &mut [Self],
275+
) -> Result<LeastSquaresOwned<Self>> {
276+
use least_squares::*;
277+
let work = LeastSquaresWork::<$s>::new(a_layout, b_layout)?;
278+
work.eval(a, b)
279+
}
250280
}
251281
};
252282
}

ndarray-linalg/src/least_squares.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
340340
/// valid representation for `ArrayBase` (over `E`).
341341
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix2> for ArrayBase<D1, Ix2>
342342
where
343-
E: Scalar + Lapack + LeastSquaresSvdDivideConquer_,
343+
E: Scalar + Lapack,
344344
D1: DataMut<Elem = E>,
345345
D2: DataMut<Elem = E>,
346346
{

0 commit comments

Comments
 (0)