Skip to content

Commit 5d07538

Browse files
committed
impl EighWorkImpl for EighWork in c32, f32, f64
1 parent b864638 commit 5d07538

File tree

1 file changed

+157
-72
lines changed

1 file changed

+157
-72
lines changed

lax/src/eigh.rs

Lines changed: 157 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -39,81 +39,166 @@ pub trait EighWorkImpl: Sized {
3939
fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
4040
}
4141

42-
impl EighWorkImpl for EighWork<c64> {
43-
type Elem = c64;
44-
45-
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
46-
assert_eq!(layout.len(), layout.lda());
47-
let n = layout.len();
48-
let jobz = if calc_eigenvectors {
49-
JobEv::All
50-
} else {
51-
JobEv::None
52-
};
53-
let mut eigs = vec_uninit(n as usize);
54-
let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
55-
let mut info = 0;
56-
let mut work_size = [c64::zero()];
57-
unsafe {
58-
lapack_sys::zheev_(
59-
jobz.as_ptr(),
60-
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
61-
&n,
62-
std::ptr::null_mut(),
63-
&n,
64-
AsPtr::as_mut_ptr(&mut eigs),
65-
AsPtr::as_mut_ptr(&mut work_size),
66-
&(-1),
67-
AsPtr::as_mut_ptr(&mut rwork),
68-
&mut info,
69-
);
70-
}
71-
info.as_lapack_result()?;
72-
let lwork = work_size[0].to_usize().unwrap();
73-
let work = vec_uninit(lwork);
74-
Ok(EighWork {
75-
n,
76-
eigs,
77-
jobz,
78-
work,
79-
rwork: Some(rwork),
80-
})
81-
}
82-
83-
fn calc(
84-
&mut self,
85-
uplo: UPLO,
86-
a: &mut [Self::Elem],
87-
) -> Result<&[<Self::Elem as Scalar>::Real]> {
88-
let lwork = self.work.len().to_i32().unwrap();
89-
let mut info = 0;
90-
unsafe {
91-
lapack_sys::zheev_(
92-
self.jobz.as_ptr(),
93-
uplo.as_ptr(),
94-
&self.n,
95-
AsPtr::as_mut_ptr(a),
96-
&self.n,
97-
AsPtr::as_mut_ptr(&mut self.eigs),
98-
AsPtr::as_mut_ptr(&mut self.work),
99-
&lwork,
100-
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
101-
&mut info,
102-
);
42+
macro_rules! impl_eigh_work_c {
43+
($c:ty, $ev:path) => {
44+
impl EighWorkImpl for EighWork<$c> {
45+
type Elem = $c;
46+
47+
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
48+
assert_eq!(layout.len(), layout.lda());
49+
let n = layout.len();
50+
let jobz = if calc_eigenvectors {
51+
JobEv::All
52+
} else {
53+
JobEv::None
54+
};
55+
let mut eigs = vec_uninit(n as usize);
56+
let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
57+
let mut info = 0;
58+
let mut work_size = [Self::Elem::zero()];
59+
unsafe {
60+
$ev(
61+
jobz.as_ptr(),
62+
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
63+
&n,
64+
std::ptr::null_mut(),
65+
&n,
66+
AsPtr::as_mut_ptr(&mut eigs),
67+
AsPtr::as_mut_ptr(&mut work_size),
68+
&(-1),
69+
AsPtr::as_mut_ptr(&mut rwork),
70+
&mut info,
71+
);
72+
}
73+
info.as_lapack_result()?;
74+
let lwork = work_size[0].to_usize().unwrap();
75+
let work = vec_uninit(lwork);
76+
Ok(EighWork {
77+
n,
78+
eigs,
79+
jobz,
80+
work,
81+
rwork: Some(rwork),
82+
})
83+
}
84+
85+
fn calc(
86+
&mut self,
87+
uplo: UPLO,
88+
a: &mut [Self::Elem],
89+
) -> Result<&[<Self::Elem as Scalar>::Real]> {
90+
let lwork = self.work.len().to_i32().unwrap();
91+
let mut info = 0;
92+
unsafe {
93+
$ev(
94+
self.jobz.as_ptr(),
95+
uplo.as_ptr(),
96+
&self.n,
97+
AsPtr::as_mut_ptr(a),
98+
&self.n,
99+
AsPtr::as_mut_ptr(&mut self.eigs),
100+
AsPtr::as_mut_ptr(&mut self.work),
101+
&lwork,
102+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
103+
&mut info,
104+
);
105+
}
106+
info.as_lapack_result()?;
107+
Ok(unsafe { self.eigs.slice_assume_init_ref() })
108+
}
109+
110+
fn eval(
111+
mut self,
112+
uplo: UPLO,
113+
a: &mut [Self::Elem],
114+
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
115+
let _eig = self.calc(uplo, a)?;
116+
Ok(unsafe { self.eigs.assume_init() })
117+
}
103118
}
104-
info.as_lapack_result()?;
105-
Ok(unsafe { self.eigs.slice_assume_init_ref() })
106-
}
119+
};
120+
}
121+
impl_eigh_work_c!(c64, lapack_sys::zheev_);
122+
impl_eigh_work_c!(c32, lapack_sys::cheev_);
107123

108-
fn eval(
109-
mut self,
110-
uplo: UPLO,
111-
a: &mut [Self::Elem],
112-
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
113-
let _eig = self.calc(uplo, a)?;
114-
Ok(unsafe { self.eigs.assume_init() })
115-
}
124+
macro_rules! impl_eigh_work_r {
125+
($f:ty, $ev:path) => {
126+
impl EighWorkImpl for EighWork<$f> {
127+
type Elem = $f;
128+
129+
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
130+
assert_eq!(layout.len(), layout.lda());
131+
let n = layout.len();
132+
let jobz = if calc_eigenvectors {
133+
JobEv::All
134+
} else {
135+
JobEv::None
136+
};
137+
let mut eigs = vec_uninit(n as usize);
138+
let mut info = 0;
139+
let mut work_size = [Self::Elem::zero()];
140+
unsafe {
141+
$ev(
142+
jobz.as_ptr(),
143+
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
144+
&n,
145+
std::ptr::null_mut(),
146+
&n,
147+
AsPtr::as_mut_ptr(&mut eigs),
148+
AsPtr::as_mut_ptr(&mut work_size),
149+
&(-1),
150+
&mut info,
151+
);
152+
}
153+
info.as_lapack_result()?;
154+
let lwork = work_size[0].to_usize().unwrap();
155+
let work = vec_uninit(lwork);
156+
Ok(EighWork {
157+
n,
158+
eigs,
159+
jobz,
160+
work,
161+
rwork: None,
162+
})
163+
}
164+
165+
fn calc(
166+
&mut self,
167+
uplo: UPLO,
168+
a: &mut [Self::Elem],
169+
) -> Result<&[<Self::Elem as Scalar>::Real]> {
170+
let lwork = self.work.len().to_i32().unwrap();
171+
let mut info = 0;
172+
unsafe {
173+
$ev(
174+
self.jobz.as_ptr(),
175+
uplo.as_ptr(),
176+
&self.n,
177+
AsPtr::as_mut_ptr(a),
178+
&self.n,
179+
AsPtr::as_mut_ptr(&mut self.eigs),
180+
AsPtr::as_mut_ptr(&mut self.work),
181+
&lwork,
182+
&mut info,
183+
);
184+
}
185+
info.as_lapack_result()?;
186+
Ok(unsafe { self.eigs.slice_assume_init_ref() })
187+
}
188+
189+
fn eval(
190+
mut self,
191+
uplo: UPLO,
192+
a: &mut [Self::Elem],
193+
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
194+
let _eig = self.calc(uplo, a)?;
195+
Ok(unsafe { self.eigs.assume_init() })
196+
}
197+
}
198+
};
116199
}
200+
impl_eigh_work_r!(f64, lapack_sys::dsyev_);
201+
impl_eigh_work_r!(f32, lapack_sys::ssyev_);
117202

118203
macro_rules! impl_eigh {
119204
(@real, $scalar:ty, $ev:path) => {

0 commit comments

Comments
 (0)