Skip to content

Commit f9f16e2

Browse files
committed
Add SvdWork<T>
1 parent b74552e commit f9f16e2

File tree

1 file changed

+301
-0
lines changed

1 file changed

+301
-0
lines changed

lax/src/svd.rs

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,307 @@ pub trait SVD_: Scalar {
3030
-> Result<SVDOutput<Self>>;
3131
}
3232

33+
pub struct SvdWork<T: Scalar> {
34+
pub ju: JobSvd,
35+
pub jvt: JobSvd,
36+
pub layout: MatrixLayout,
37+
pub s: Vec<MaybeUninit<T::Real>>,
38+
pub u: Option<Vec<MaybeUninit<T>>>,
39+
pub vt: Option<Vec<MaybeUninit<T>>>,
40+
pub work: Vec<MaybeUninit<T>>,
41+
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
42+
}
43+
44+
#[derive(Debug, Clone)]
45+
pub struct SvdRef<'work, T: Scalar> {
46+
pub s: &'work [T::Real],
47+
pub u: Option<&'work [T]>,
48+
pub vt: Option<&'work [T]>,
49+
}
50+
51+
#[derive(Debug, Clone)]
52+
pub struct SvdOwned<T: Scalar> {
53+
pub s: Vec<T::Real>,
54+
pub u: Option<Vec<T>>,
55+
pub vt: Option<Vec<T>>,
56+
}
57+
58+
pub trait SvdWorkImpl: Sized {
59+
type Elem: Scalar;
60+
fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self>;
61+
fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>>;
62+
fn eval(self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>>;
63+
}
64+
65+
macro_rules! impl_svd_work_c {
66+
($s:ty, $svd:path) => {
67+
impl SvdWorkImpl for SvdWork<$s> {
68+
type Elem = $s;
69+
70+
fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self> {
71+
let ju = match layout {
72+
MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
73+
MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
74+
};
75+
let jvt = match layout {
76+
MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
77+
MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
78+
};
79+
80+
let m = layout.lda();
81+
let mut u = match ju {
82+
JobSvd::All => Some(vec_uninit((m * m) as usize)),
83+
JobSvd::None => None,
84+
_ => unimplemented!("SVD with partial vector output is not supported yet"),
85+
};
86+
87+
let n = layout.len();
88+
let mut vt = match jvt {
89+
JobSvd::All => Some(vec_uninit((n * n) as usize)),
90+
JobSvd::None => None,
91+
_ => unimplemented!("SVD with partial vector output is not supported yet"),
92+
};
93+
94+
let k = std::cmp::min(m, n);
95+
let mut s = vec_uninit(k as usize);
96+
let mut rwork = vec_uninit(5 * k as usize);
97+
98+
// eval work size
99+
let mut info = 0;
100+
let mut work_size = [Self::Elem::zero()];
101+
unsafe {
102+
$svd(
103+
ju.as_ptr(),
104+
jvt.as_ptr(),
105+
&m,
106+
&n,
107+
std::ptr::null_mut(),
108+
&m,
109+
AsPtr::as_mut_ptr(&mut s),
110+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
111+
&m,
112+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
113+
&n,
114+
AsPtr::as_mut_ptr(&mut work_size),
115+
&(-1),
116+
AsPtr::as_mut_ptr(&mut rwork),
117+
&mut info,
118+
);
119+
}
120+
info.as_lapack_result()?;
121+
let lwork = work_size[0].to_usize().unwrap();
122+
let work = vec_uninit(lwork);
123+
Ok(SvdWork {
124+
layout,
125+
ju,
126+
jvt,
127+
s,
128+
u,
129+
vt,
130+
work,
131+
rwork: Some(rwork),
132+
})
133+
}
134+
135+
fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
136+
let m = self.layout.lda();
137+
let n = self.layout.len();
138+
let lwork = self.work.len().to_i32().unwrap();
139+
140+
let mut info = 0;
141+
unsafe {
142+
$svd(
143+
self.ju.as_ptr(),
144+
self.jvt.as_ptr(),
145+
&m,
146+
&n,
147+
AsPtr::as_mut_ptr(a),
148+
&m,
149+
AsPtr::as_mut_ptr(&mut self.s),
150+
AsPtr::as_mut_ptr(
151+
self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
152+
),
153+
&m,
154+
AsPtr::as_mut_ptr(
155+
self.vt
156+
.as_mut()
157+
.map(|x| x.as_mut_slice())
158+
.unwrap_or(&mut []),
159+
),
160+
&n,
161+
AsPtr::as_mut_ptr(&mut self.work),
162+
&(lwork as i32),
163+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
164+
&mut info,
165+
);
166+
}
167+
info.as_lapack_result()?;
168+
169+
let s = unsafe { self.s.slice_assume_init_ref() };
170+
let u = self
171+
.u
172+
.as_ref()
173+
.map(|v| unsafe { v.slice_assume_init_ref() });
174+
let vt = self
175+
.vt
176+
.as_ref()
177+
.map(|v| unsafe { v.slice_assume_init_ref() });
178+
179+
match self.layout {
180+
MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }),
181+
MatrixLayout::C { .. } => Ok(SvdRef { s, u: vt, vt: u }),
182+
}
183+
}
184+
185+
fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
186+
let _ref = self.calc(a)?;
187+
let s = unsafe { self.s.assume_init() };
188+
let u = self.u.map(|v| unsafe { v.assume_init() });
189+
let vt = self.vt.map(|v| unsafe { v.assume_init() });
190+
match self.layout {
191+
MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
192+
MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
193+
}
194+
}
195+
}
196+
};
197+
}
198+
impl_svd_work_c!(c64, lapack_sys::zgesvd_);
199+
impl_svd_work_c!(c32, lapack_sys::cgesvd_);
200+
201+
macro_rules! impl_svd_work_r {
202+
($s:ty, $svd:path) => {
203+
impl SvdWorkImpl for SvdWork<$s> {
204+
type Elem = $s;
205+
206+
fn new(layout: MatrixLayout, calc_u: bool, calc_vt: bool) -> Result<Self> {
207+
let ju = match layout {
208+
MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
209+
MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
210+
};
211+
let jvt = match layout {
212+
MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
213+
MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
214+
};
215+
216+
let m = layout.lda();
217+
let mut u = match ju {
218+
JobSvd::All => Some(vec_uninit((m * m) as usize)),
219+
JobSvd::None => None,
220+
_ => unimplemented!("SVD with partial vector output is not supported yet"),
221+
};
222+
223+
let n = layout.len();
224+
let mut vt = match jvt {
225+
JobSvd::All => Some(vec_uninit((n * n) as usize)),
226+
JobSvd::None => None,
227+
_ => unimplemented!("SVD with partial vector output is not supported yet"),
228+
};
229+
230+
let k = std::cmp::min(m, n);
231+
let mut s = vec_uninit(k as usize);
232+
233+
// eval work size
234+
let mut info = 0;
235+
let mut work_size = [Self::Elem::zero()];
236+
unsafe {
237+
$svd(
238+
ju.as_ptr(),
239+
jvt.as_ptr(),
240+
&m,
241+
&n,
242+
std::ptr::null_mut(),
243+
&m,
244+
AsPtr::as_mut_ptr(&mut s),
245+
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
246+
&m,
247+
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
248+
&n,
249+
AsPtr::as_mut_ptr(&mut work_size),
250+
&(-1),
251+
&mut info,
252+
);
253+
}
254+
info.as_lapack_result()?;
255+
let lwork = work_size[0].to_usize().unwrap();
256+
let work = vec_uninit(lwork);
257+
Ok(SvdWork {
258+
layout,
259+
ju,
260+
jvt,
261+
s,
262+
u,
263+
vt,
264+
work,
265+
rwork: None,
266+
})
267+
}
268+
269+
fn calc(&mut self, a: &mut [Self::Elem]) -> Result<SvdRef<Self::Elem>> {
270+
let m = self.layout.lda();
271+
let n = self.layout.len();
272+
let lwork = self.work.len().to_i32().unwrap();
273+
274+
let mut info = 0;
275+
unsafe {
276+
$svd(
277+
self.ju.as_ptr(),
278+
self.jvt.as_ptr(),
279+
&m,
280+
&n,
281+
AsPtr::as_mut_ptr(a),
282+
&m,
283+
AsPtr::as_mut_ptr(&mut self.s),
284+
AsPtr::as_mut_ptr(
285+
self.u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut []),
286+
),
287+
&m,
288+
AsPtr::as_mut_ptr(
289+
self.vt
290+
.as_mut()
291+
.map(|x| x.as_mut_slice())
292+
.unwrap_or(&mut []),
293+
),
294+
&n,
295+
AsPtr::as_mut_ptr(&mut self.work),
296+
&(lwork as i32),
297+
&mut info,
298+
);
299+
}
300+
info.as_lapack_result()?;
301+
302+
let s = unsafe { self.s.slice_assume_init_ref() };
303+
let u = self
304+
.u
305+
.as_ref()
306+
.map(|v| unsafe { v.slice_assume_init_ref() });
307+
let vt = self
308+
.vt
309+
.as_ref()
310+
.map(|v| unsafe { v.slice_assume_init_ref() });
311+
312+
match self.layout {
313+
MatrixLayout::F { .. } => Ok(SvdRef { s, u, vt }),
314+
MatrixLayout::C { .. } => Ok(SvdRef { s, u: vt, vt: u }),
315+
}
316+
}
317+
318+
fn eval(mut self, a: &mut [Self::Elem]) -> Result<SvdOwned<Self::Elem>> {
319+
let _ref = self.calc(a)?;
320+
let s = unsafe { self.s.assume_init() };
321+
let u = self.u.map(|v| unsafe { v.assume_init() });
322+
let vt = self.vt.map(|v| unsafe { v.assume_init() });
323+
match self.layout {
324+
MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
325+
MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
326+
}
327+
}
328+
}
329+
};
330+
}
331+
impl_svd_work_r!(f64, lapack_sys::dgesvd_);
332+
impl_svd_work_r!(f32, lapack_sys::sgesvd_);
333+
33334
macro_rules! impl_svd {
34335
(@real, $scalar:ty, $gesvd:path) => {
35336
impl_svd!(@body, $scalar, $gesvd, );

0 commit comments

Comments
 (0)