Skip to content

Commit fb67133

Browse files
committed
Split primal functions
1 parent b6f64dd commit fb67133

File tree

1 file changed

+52
-10
lines changed

1 file changed

+52
-10
lines changed

src/krylov/householder.rs

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,34 @@
11
use super::*;
22
use crate::{inner::*, norm::*};
3-
use num_traits::Zero;
3+
use num_traits::{One, Zero};
4+
5+
/// Calc a reflactor `w` from a vector `x`
6+
pub fn calc_reflector<A, S>(x: &mut ArrayBase<S, Ix1>)
7+
where
8+
A: Scalar + Lapack,
9+
S: DataMut<Elem = A>,
10+
{
11+
let norm = x.norm_l2();
12+
let alpha = x[0].mul_real(norm / x[0].abs());
13+
x[0] -= alpha;
14+
let inv_rev_norm = A::Real::one() / x.norm_l2();
15+
azip!(mut a(x) in { *a = a.mul_real(inv_rev_norm)});
16+
}
17+
18+
/// Take a reflection using `w`
19+
pub fn reflect<A, S1, S2>(w: &ArrayBase<S1, Ix1>, a: &mut ArrayBase<S2, Ix1>)
20+
where
21+
A: Scalar + Lapack,
22+
S1: Data<Elem = A>,
23+
S2: DataMut<Elem = A>,
24+
{
25+
assert_eq!(w.len(), a.len());
26+
let n = a.len();
27+
let c = A::from(2.0).unwrap() * w.inner(&a);
28+
for l in 0..n {
29+
a[l] -= c * w[l];
30+
}
31+
}
432

533
/// Iterative orthogonalizer using Householder reflection
634
#[derive(Debug, Clone)]
@@ -27,13 +55,7 @@ impl<A: Scalar + Lapack> Householder<A> {
2755
{
2856
assert!(k < self.v.len());
2957
assert_eq!(a.len(), self.dim, "Input array size mismaches to the dimension");
30-
31-
let w = self.v[k].slice(s![k..]);
32-
let mut a_slice = a.slice_mut(s![k..]);
33-
let c = A::from(2.0).unwrap() * w.inner(&a_slice);
34-
for l in 0..self.dim - k {
35-
a_slice[l] -= c * w[l];
36-
}
58+
reflect(&self.v[k].slice(s![k..]), &mut a.slice_mut(s![k..]));
3759
}
3860

3961
/// Take forward reflection `P = P_l ... P_1`
@@ -110,14 +132,15 @@ impl<A: Scalar + Lapack> Orthogonalizer for Householder<A> {
110132
for i in 0..k {
111133
coef[i] = a[i];
112134
}
135+
coef[k] = A::from_real(alpha);
113136
if alpha < rtol {
114137
// linearly dependent
115-
coef[k] = A::from_real(alpha);
116138
return Err(coef);
117139
}
118140

119-
// Add reflector
120141
assert!(k < a.len()); // this must hold because `alpha == 0` if k >= a.len()
142+
143+
// Add reflector
121144
let alpha = if a[k].abs() > Zero::zero() {
122145
a[k].mul_real(alpha / a[k].abs())
123146
} else {
@@ -158,3 +181,22 @@ where
158181
let h = Householder::new(dim);
159182
qr(iter, h, rtol, strategy)
160183
}
184+
185+
#[cfg(test)]
186+
mod tests {
187+
use super::*;
188+
use crate::assert::*;
189+
190+
#[test]
191+
fn check_reflector() {
192+
let mut a = array![c64::new(1.0, 1.0), c64::new(1.0, 0.0), c64::new(0.0, 1.0)];
193+
let mut w = a.clone();
194+
calc_reflector(&mut w);
195+
reflect(&w, &mut a);
196+
close_l2(
197+
&a,
198+
&array![c64::new(2.0.sqrt(), 2.0.sqrt()), c64::zero(), c64::zero()],
199+
1e-9,
200+
);
201+
}
202+
}

0 commit comments

Comments
 (0)