Skip to content

Commit b864638

Browse files
committed
EighWork<c64>
1 parent c953001 commit b864638

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

lax/src/eigh.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,98 @@ pub trait Eigh_: Scalar {
2323
) -> Result<Vec<Self::Real>>;
2424
}
2525

26+
pub struct EighWork<T: Scalar> {
27+
pub n: i32,
28+
pub jobz: JobEv,
29+
pub eigs: Vec<MaybeUninit<T::Real>>,
30+
pub work: Vec<MaybeUninit<T>>,
31+
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
32+
}
33+
34+
pub trait EighWorkImpl: Sized {
35+
type Elem: Scalar;
36+
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>;
37+
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem])
38+
-> Result<&[<Self::Elem as Scalar>::Real]>;
39+
fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
40+
}
41+
42+
impl EighWorkImpl for EighWork<c64> {
43+
type Elem = c64;
44+
45+
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
46+
assert_eq!(layout.len(), layout.lda());
47+
let n = layout.len();
48+
let jobz = if calc_eigenvectors {
49+
JobEv::All
50+
} else {
51+
JobEv::None
52+
};
53+
let mut eigs = vec_uninit(n as usize);
54+
let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
55+
let mut info = 0;
56+
let mut work_size = [c64::zero()];
57+
unsafe {
58+
lapack_sys::zheev_(
59+
jobz.as_ptr(),
60+
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
61+
&n,
62+
std::ptr::null_mut(),
63+
&n,
64+
AsPtr::as_mut_ptr(&mut eigs),
65+
AsPtr::as_mut_ptr(&mut work_size),
66+
&(-1),
67+
AsPtr::as_mut_ptr(&mut rwork),
68+
&mut info,
69+
);
70+
}
71+
info.as_lapack_result()?;
72+
let lwork = work_size[0].to_usize().unwrap();
73+
let work = vec_uninit(lwork);
74+
Ok(EighWork {
75+
n,
76+
eigs,
77+
jobz,
78+
work,
79+
rwork: Some(rwork),
80+
})
81+
}
82+
83+
fn calc(
84+
&mut self,
85+
uplo: UPLO,
86+
a: &mut [Self::Elem],
87+
) -> Result<&[<Self::Elem as Scalar>::Real]> {
88+
let lwork = self.work.len().to_i32().unwrap();
89+
let mut info = 0;
90+
unsafe {
91+
lapack_sys::zheev_(
92+
self.jobz.as_ptr(),
93+
uplo.as_ptr(),
94+
&self.n,
95+
AsPtr::as_mut_ptr(a),
96+
&self.n,
97+
AsPtr::as_mut_ptr(&mut self.eigs),
98+
AsPtr::as_mut_ptr(&mut self.work),
99+
&lwork,
100+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
101+
&mut info,
102+
);
103+
}
104+
info.as_lapack_result()?;
105+
Ok(unsafe { self.eigs.slice_assume_init_ref() })
106+
}
107+
108+
fn eval(
109+
mut self,
110+
uplo: UPLO,
111+
a: &mut [Self::Elem],
112+
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
113+
let _eig = self.calc(uplo, a)?;
114+
Ok(unsafe { self.eigs.assume_init() })
115+
}
116+
}
117+
26118
macro_rules! impl_eigh {
27119
(@real, $scalar:ty, $ev:path) => {
28120
impl_eigh!(@body, $scalar, $ev, );

0 commit comments

Comments
 (0)