Skip to content

Commit f30931d

Browse files
committed
SvdDcWork and SvdDcWorkImpl
1 parent 3fd8c0b commit f30931d

File tree

2 files changed

+296
-2
lines changed

2 files changed

+296
-2
lines changed

lax/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ pub use self::opnorm::*;
111111
pub use self::rcond::*;
112112
pub use self::solve::*;
113113
pub use self::solveh::*;
114-
pub use self::svd::SvdOwned;
114+
pub use self::svd::{SvdOwned, SvdRef};
115115
pub use self::svddc::*;
116116
pub use self::triangular::*;
117117
pub use self::tridiagonal::*;

lax/src/svddc.rs

Lines changed: 295 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,303 @@ pub trait SVDDC_: Scalar {
1414
/// |:-------|:-------|:-------|:-------|
1515
/// | sgesdd | dgesdd | cgesdd | zgesdd |
1616
///
17-
fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>>;
17+
fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>>;
1818
}
1919

20+
pub struct SvdDcWork<T: Scalar> {
21+
pub jobz: JobSvd,
22+
pub layout: MatrixLayout,
23+
pub s: Vec<MaybeUninit<T::Real>>,
24+
pub u: Option<Vec<MaybeUninit<T>>>,
25+
pub vt: Option<Vec<MaybeUninit<T>>>,
26+
pub work: Vec<MaybeUninit<T>>,
27+
pub iwork: Vec<MaybeUninit<i32>>,
28+
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
29+
}
30+
31+
pub trait SvdDcWorkImpl: Sized {
32+
type Elem: Scalar;
33+
fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self>;
34+
fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>>;
35+
fn eval(self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>>;
36+
}
37+
38+
macro_rules! impl_svd_dc_work_c {
39+
($s:ty, $sdd:path) => {
40+
impl SvdDcWorkImpl for SvdDcWork<$s> {
41+
type Elem = $s;
42+
43+
fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self> {
44+
let m = layout.lda();
45+
let n = layout.len();
46+
let k = m.min(n);
47+
let (u_col, vt_row) = match jobz {
48+
JobSvd::All | JobSvd::None => (m, n),
49+
JobSvd::Some => (k, k),
50+
};
51+
52+
let mut s = vec_uninit(k as usize);
53+
let (mut u, mut vt) = match jobz {
54+
JobSvd::All => (
55+
Some(vec_uninit((m * m) as usize)),
56+
Some(vec_uninit((n * n) as usize)),
57+
),
58+
JobSvd::Some => (
59+
Some(vec_uninit((m * u_col) as usize)),
60+
Some(vec_uninit((n * vt_row) as usize)),
61+
),
62+
JobSvd::None => (None, None),
63+
};
64+
let mut iwork = vec_uninit(8 * k as usize);
65+
66+
let mx = n.max(m) as usize;
67+
let mn = n.min(m) as usize;
68+
let lrwork = match jobz {
69+
JobSvd::None => 7 * mn,
70+
_ => std::cmp::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
71+
};
72+
let mut rwork = vec_uninit(lrwork);
73+
74+
let mut info = 0;
75+
let mut work_size = [Self::Elem::zero()];
76+
unsafe {
77+
$sdd(
78+
jobz.as_ptr(),
79+
&m,
80+
&n,
81+
std::ptr::null_mut(),
82+
&m,
83+
AsPtr::as_mut_ptr(&mut s),
84+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
85+
&m,
86+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
87+
&vt_row,
88+
AsPtr::as_mut_ptr(&mut work_size),
89+
&(-1),
90+
AsPtr::as_mut_ptr(&mut rwork),
91+
AsPtr::as_mut_ptr(&mut iwork),
92+
&mut info,
93+
);
94+
}
95+
info.as_lapack_result()?;
96+
let lwork = work_size[0].to_usize().unwrap();
97+
let work = vec_uninit(lwork);
98+
Ok(SvdDcWork {
99+
layout,
100+
jobz,
101+
iwork,
102+
work,
103+
rwork: Some(rwork),
104+
u,
105+
vt,
106+
s,
107+
})
108+
}
109+
110+
fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
111+
let m = self.layout.lda();
112+
let n = self.layout.len();
113+
let k = m.min(n);
114+
let (_, vt_row) = match self.jobz {
115+
JobSvd::All | JobSvd::None => (m, n),
116+
JobSvd::Some => (k, k),
117+
};
118+
let lwork = self.work.len().to_i32().unwrap();
119+
120+
let mut info = 0;
121+
unsafe {
122+
$sdd(
123+
self.jobz.as_ptr(),
124+
&m,
125+
&n,
126+
AsPtr::as_mut_ptr(a),
127+
&m,
128+
AsPtr::as_mut_ptr(&mut self.s),
129+
AsPtr::as_mut_ptr(
130+
self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
131+
),
132+
&m,
133+
AsPtr::as_mut_ptr(
134+
self.vt
135+
.as_mut()
136+
.map(|x| x.as_mut_slice())
137+
.unwrap_or(&mut []),
138+
),
139+
&vt_row,
140+
AsPtr::as_mut_ptr(&mut self.work),
141+
&lwork,
142+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
143+
AsPtr::as_mut_ptr(&mut self.iwork),
144+
&mut info,
145+
);
146+
}
147+
info.as_lapack_result()?;
148+
149+
let s = unsafe { self.s.slice_assume_init_ref() };
150+
let u = self
151+
.u
152+
.as_ref()
153+
.map(|v| unsafe { v.slice_assume_init_ref() });
154+
let vt = self
155+
.vt
156+
.as_ref()
157+
.map(|v| unsafe { v.slice_assume_init_ref() });
158+
159+
Ok(match self.layout {
160+
MatrixLayout::F { .. } => SvdRef { s, u, vt },
161+
MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u },
162+
})
163+
}
164+
165+
fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
166+
let _ref = self.calc(a)?;
167+
let s = unsafe { self.s.assume_init() };
168+
let u = self.u.map(|v| unsafe { v.assume_init() });
169+
let vt = self.vt.map(|v| unsafe { v.assume_init() });
170+
Ok(match self.layout {
171+
MatrixLayout::F { .. } => SvdOwned { s, u, vt },
172+
MatrixLayout::C { .. } => SvdOwned { s, u: vt, vt: u },
173+
})
174+
}
175+
}
176+
};
177+
}
178+
impl_svd_dc_work_c!(c64, lapack_sys::zgesdd_);
179+
impl_svd_dc_work_c!(c32, lapack_sys::cgesdd_);
180+
181+
macro_rules! impl_svd_dc_work_r {
182+
($s:ty, $sdd:path) => {
183+
impl SvdDcWorkImpl for SvdDcWork<$s> {
184+
type Elem = $s;
185+
186+
fn new(layout: MatrixLayout, jobz: JobSvd) -> Result<Self> {
187+
let m = layout.lda();
188+
let n = layout.len();
189+
let k = m.min(n);
190+
let (u_col, vt_row) = match jobz {
191+
JobSvd::All | JobSvd::None => (m, n),
192+
JobSvd::Some => (k, k),
193+
};
194+
195+
let mut s = vec_uninit(k as usize);
196+
let (mut u, mut vt) = match jobz {
197+
JobSvd::All => (
198+
Some(vec_uninit((m * m) as usize)),
199+
Some(vec_uninit((n * n) as usize)),
200+
),
201+
JobSvd::Some => (
202+
Some(vec_uninit((m * u_col) as usize)),
203+
Some(vec_uninit((n * vt_row) as usize)),
204+
),
205+
JobSvd::None => (None, None),
206+
};
207+
let mut iwork = vec_uninit(8 * k as usize);
208+
209+
let mut info = 0;
210+
let mut work_size = [Self::Elem::zero()];
211+
unsafe {
212+
$sdd(
213+
jobz.as_ptr(),
214+
&m,
215+
&n,
216+
std::ptr::null_mut(),
217+
&m,
218+
AsPtr::as_mut_ptr(&mut s),
219+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
220+
&m,
221+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
222+
&vt_row,
223+
AsPtr::as_mut_ptr(&mut work_size),
224+
&(-1),
225+
AsPtr::as_mut_ptr(&mut iwork),
226+
&mut info,
227+
);
228+
}
229+
info.as_lapack_result()?;
230+
let lwork = work_size[0].to_usize().unwrap();
231+
let work = vec_uninit(lwork);
232+
Ok(SvdDcWork {
233+
layout,
234+
jobz,
235+
iwork,
236+
work,
237+
rwork: None,
238+
u,
239+
vt,
240+
s,
241+
})
242+
}
243+
244+
fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
245+
let m = self.layout.lda();
246+
let n = self.layout.len();
247+
let k = m.min(n);
248+
let (_, vt_row) = match self.jobz {
249+
JobSvd::All | JobSvd::None => (m, n),
250+
JobSvd::Some => (k, k),
251+
};
252+
let lwork = self.work.len().to_i32().unwrap();
253+
254+
let mut info = 0;
255+
unsafe {
256+
$sdd(
257+
self.jobz.as_ptr(),
258+
&m,
259+
&n,
260+
AsPtr::as_mut_ptr(a),
261+
&m,
262+
AsPtr::as_mut_ptr(&mut self.s),
263+
AsPtr::as_mut_ptr(
264+
self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
265+
),
266+
&m,
267+
AsPtr::as_mut_ptr(
268+
self.vt
269+
.as_mut()
270+
.map(|x| x.as_mut_slice())
271+
.unwrap_or(&mut []),
272+
),
273+
&vt_row,
274+
AsPtr::as_mut_ptr(&mut self.work),
275+
&lwork,
276+
AsPtr::as_mut_ptr(&mut self.iwork),
277+
&mut info,
278+
);
279+
}
280+
info.as_lapack_result()?;
281+
282+
let s = unsafe { self.s.slice_assume_init_ref() };
283+
let u = self
284+
.u
285+
.as_ref()
286+
.map(|v| unsafe { v.slice_assume_init_ref() });
287+
let vt = self
288+
.vt
289+
.as_ref()
290+
.map(|v| unsafe { v.slice_assume_init_ref() });
291+
292+
Ok(match self.layout {
293+
MatrixLayout::F { .. } => SvdRef { s, u, vt },
294+
MatrixLayout::C { .. } => SvdRef { s, u: vt, vt: u },
295+
})
296+
}
297+
298+
fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
299+
let _ref = self.calc(a)?;
300+
let s = unsafe { self.s.assume_init() };
301+
let u = self.u.map(|v| unsafe { v.assume_init() });
302+
let vt = self.vt.map(|v| unsafe { v.assume_init() });
303+
Ok(match self.layout {
304+
MatrixLayout::F { .. } => SvdOwned { s, u, vt },
305+
MatrixLayout::C { .. } => SvdOwned { s, u: vt, vt: u },
306+
})
307+
}
308+
}
309+
};
310+
}
311+
impl_svd_dc_work_r!(f64, lapack_sys::dgesdd_);
312+
impl_svd_dc_work_r!(f32, lapack_sys::sgesdd_);
313+
20314
macro_rules! impl_svddc {
21315
(@real, $scalar:ty, $gesdd:path) => {
22316
impl_svddc!(@body, $scalar, $gesdd, );

0 commit comments

Comments
 (0)