Skip to content

Commit 866eae8

Browse files
committed
Replace operator::Operator* by operator::LinearOperator
1 parent 316e284 commit 866eae8

File tree

6 files changed

+76
-241
lines changed

6 files changed

+76
-241
lines changed

src/diagonal.rs

Lines changed: 10 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
33
use ndarray::*;
44

5-
use super::convert::*;
65
use super::operator::*;
6+
use super::types::*;
77

88
/// Vector as a Diagonal matrix
99
pub struct Diagonal<S: Data> {
@@ -30,81 +30,19 @@ impl<A, S: Data<Elem = A>> AsDiagonal<A> for ArrayBase<S, Ix1> {
3030
}
3131
}
3232

33-
impl<A, S, Sr> OperatorInplace<Sr, Ix1> for Diagonal<S>
33+
impl<A, Sa> LinearOperator for Diagonal<Sa>
3434
where
35-
A: LinalgScalar,
36-
S: Data<Elem = A>,
37-
Sr: DataMut<Elem = A>,
35+
A: Scalar,
36+
Sa: Data<Elem = A>,
3837
{
39-
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix1>) -> &'a mut ArrayBase<Sr, Ix1> {
38+
type Elem = A;
39+
40+
fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
41+
where
42+
S: DataMut<Elem = A>,
43+
{
4044
for (val, d) in a.iter_mut().zip(self.diag.iter()) {
4145
*val = *val * *d;
4246
}
43-
a
44-
}
45-
}
46-
47-
impl<A, S, Sr> Operator<A, Sr, Ix1> for Diagonal<S>
48-
where
49-
A: LinalgScalar,
50-
S: Data<Elem = A>,
51-
Sr: Data<Elem = A>,
52-
{
53-
fn op(&self, a: &ArrayBase<Sr, Ix1>) -> Array1<A> {
54-
let mut a = replicate(a);
55-
self.op_inplace(&mut a);
56-
a
57-
}
58-
}
59-
60-
impl<A, S, Sr> OperatorInto<Sr, Ix1> for Diagonal<S>
61-
where
62-
A: LinalgScalar,
63-
S: Data<Elem = A>,
64-
Sr: DataOwned<Elem = A> + DataMut,
65-
{
66-
fn op_into(&self, mut a: ArrayBase<Sr, Ix1>) -> ArrayBase<Sr, Ix1> {
67-
self.op_inplace(&mut a);
68-
a
69-
}
70-
}
71-
72-
impl<A, S, Sr> OperatorInplace<Sr, Ix2> for Diagonal<S>
73-
where
74-
A: LinalgScalar,
75-
S: Data<Elem = A>,
76-
Sr: DataMut<Elem = A>,
77-
{
78-
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<Sr, Ix2>) -> &'a mut ArrayBase<Sr, Ix2> {
79-
let d = &self.diag;
80-
for ((i, _), val) in a.indexed_iter_mut() {
81-
*val = *val * d[i];
82-
}
83-
a
84-
}
85-
}
86-
87-
impl<A, S, Sr> Operator<A, Sr, Ix2> for Diagonal<S>
88-
where
89-
A: LinalgScalar,
90-
S: Data<Elem = A>,
91-
Sr: Data<Elem = A>,
92-
{
93-
fn op(&self, a: &ArrayBase<Sr, Ix2>) -> Array2<A> {
94-
let mut a = replicate(a);
95-
self.op_inplace(&mut a);
96-
a
97-
}
98-
}
99-
100-
impl<A, S, Sr> OperatorInto<Sr, Ix2> for Diagonal<S>
101-
where
102-
A: LinalgScalar,
103-
S: Data<Elem = A>,
104-
Sr: DataOwned<Elem = A> + DataMut,
105-
{
106-
fn op_into(&self, mut a: ArrayBase<Sr, Ix2>) -> ArrayBase<Sr, Ix2> {
107-
self.op_inplace(&mut a);
108-
a
10947
}
11048
}

src/eigh.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ndarray::*;
55
use crate::diagonal::*;
66
use crate::error::*;
77
use crate::layout::*;
8-
use crate::operator::Operator;
8+
use crate::operator::LinearOperator;
99
use crate::types::*;
1010
use crate::UPLO;
1111

@@ -165,7 +165,7 @@ where
165165
fn ssqrt_into(self, uplo: UPLO) -> Result<Self::Output> {
166166
let (e, v) = self.eigh_into(uplo)?;
167167
let e_sqrt = Array1::from_iter(e.iter().map(|r| Scalar::from_real(r.sqrt())));
168-
let ev = e_sqrt.into_diagonal().op(&v.t());
169-
Ok(v.op(&ev))
168+
let ev = e_sqrt.into_diagonal().apply2(&v.t());
169+
Ok(v.apply2(&ev))
170170
}
171171
}

src/krylov/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use ndarray::*;
66
pub mod arnoldi;
77
pub mod householder;
88
pub mod mgs;
9-
pub mod operator;
109

1110
pub use arnoldi::{arnoldi_householder, arnoldi_mgs, Arnoldi};
1211
pub use householder::{householder, Householder};

src/krylov/operator.rs

Lines changed: 0 additions & 77 deletions
This file was deleted.

src/operator.rs

Lines changed: 60 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,80 @@
1-
//! Linear Operator
1+
//! Linear operator algebra
22
3+
use crate::generate::hstack;
4+
use crate::types::*;
35
use ndarray::*;
46

5-
use super::types::*;
7+
pub trait LinearOperator {
8+
type Elem: Scalar;
69

7-
pub trait Operator<A, S, D>
8-
where
9-
S: Data<Elem = A>,
10-
D: Dimension,
11-
{
12-
fn op(&self, a: &ArrayBase<S, D>) -> Array<A, D>;
13-
}
14-
15-
pub trait OperatorInto<S, D>
16-
where
17-
S: DataMut,
18-
D: Dimension,
19-
{
20-
fn op_into(&self, a: ArrayBase<S, D>) -> ArrayBase<S, D>;
21-
}
22-
23-
pub trait OperatorInplace<S, D>
24-
where
25-
S: DataMut,
26-
D: Dimension,
27-
{
28-
fn op_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
29-
}
10+
/// Apply operator out-place
11+
fn apply<S>(&self, a: &ArrayBase<S, Ix1>) -> Array1<S::Elem>
12+
where
13+
S: Data<Elem = Self::Elem>,
14+
{
15+
let mut a = a.to_owned();
16+
self.apply_mut(&mut a);
17+
a
18+
}
3019

31-
impl<T, A, S, D> Operator<A, S, D> for T
32-
where
33-
A: Scalar + Lapack,
34-
S: Data<Elem = A>,
35-
D: Dimension,
36-
T: linalg::Dot<ArrayBase<S, D>, Output = Array<A, D>>,
37-
{
38-
fn op(&self, rhs: &ArrayBase<S, D>) -> Array<A, D> {
39-
self.dot(rhs)
20+
/// Apply operator in-place
21+
fn apply_mut<S>(&self, a: &mut ArrayBase<S, Ix1>)
22+
where
23+
S: DataMut<Elem = Self::Elem>,
24+
{
25+
let b = self.apply(a);
26+
azip!(mut a(a), b in { *a = b });
4027
}
41-
}
4228

43-
pub trait OperatorMulti<A, S, D>
44-
where
45-
S: Data<Elem = A>,
46-
D: Dimension,
47-
{
48-
fn op_multi(&self, a: &ArrayBase<S, D>) -> Array<A, D>;
49-
}
29+
/// Apply operator with move
30+
fn apply_into<S>(&self, mut a: ArrayBase<S, Ix1>) -> ArrayBase<S, Ix1>
31+
where
32+
S: DataOwned<Elem = Self::Elem> + DataMut,
33+
{
34+
self.apply_mut(&mut a);
35+
a
36+
}
5037

51-
impl<T, A, S, D> OperatorMulti<A, S, D> for T
52-
where
53-
A: Scalar + Lapack,
54-
S: DataMut<Elem = A>,
55-
D: Dimension + RemoveAxis,
56-
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
57-
{
58-
fn op_multi(&self, a: &ArrayBase<S, D>) -> Array<A, D> {
59-
let a = a.to_owned();
60-
self.op_multi_into(a)
38+
/// Apply operator to matrix out-place
39+
fn apply2<S>(&self, a: &ArrayBase<S, Ix2>) -> Array2<S::Elem>
40+
where
41+
S: Data<Elem = Self::Elem>,
42+
{
43+
let cols: Vec<_> = a.axis_iter(Axis(1)).map(|col| self.apply(&col)).collect();
44+
hstack(&cols).unwrap()
6145
}
62-
}
6346

64-
pub trait OperatorMultiInto<S, D>
65-
where
66-
S: DataMut,
67-
D: Dimension,
68-
{
69-
fn op_multi_into(&self, a: ArrayBase<S, D>) -> ArrayBase<S, D>;
70-
}
47+
/// Apply operator to matrix in-place
48+
fn apply2_mut<S>(&self, a: &mut ArrayBase<S, Ix2>)
49+
where
50+
S: DataMut<Elem = Self::Elem>,
51+
{
52+
for mut col in a.axis_iter_mut(Axis(1)) {
53+
self.apply_mut(&mut col)
54+
}
55+
}
7156

72-
impl<T, A, S, D> OperatorMultiInto<S, D> for T
73-
where
74-
S: DataMut<Elem = A>,
75-
D: Dimension + RemoveAxis,
76-
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
77-
{
78-
fn op_multi_into(&self, mut a: ArrayBase<S, D>) -> ArrayBase<S, D> {
79-
self.op_multi_inplace(&mut a);
57+
/// Apply operator to matrix with move
58+
fn apply2_into<S>(&self, mut a: ArrayBase<S, Ix2>) -> ArrayBase<S, Ix2>
59+
where
60+
S: DataOwned<Elem = Self::Elem> + DataMut,
61+
{
62+
self.apply2_mut(&mut a);
8063
a
8164
}
8265
}
8366

84-
pub trait OperatorMultiInplace<S, D>
67+
impl<A, Sa> LinearOperator for ArrayBase<Sa, Ix2>
8568
where
86-
S: DataMut,
87-
D: Dimension,
69+
A: Scalar,
70+
Sa: Data<Elem = A>,
8871
{
89-
fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D>;
90-
}
72+
type Elem = A;
9173

92-
impl<T, A, S, D> OperatorMultiInplace<S, D> for T
93-
where
94-
S: DataMut<Elem = A>,
95-
D: Dimension + RemoveAxis,
96-
for<'a> T: OperatorInplace<ViewRepr<&'a mut A>, D::Smaller>,
97-
{
98-
fn op_multi_inplace<'a>(&self, a: &'a mut ArrayBase<S, D>) -> &'a mut ArrayBase<S, D> {
99-
let n = a.ndim();
100-
for mut col in a.axis_iter_mut(Axis(n - 1)) {
101-
self.op_inplace(&mut col);
102-
}
103-
a
74+
fn apply<S>(&self, a: &ArrayBase<S, Ix1>) -> Array1<A>
75+
where
76+
S: Data<Elem = A>,
77+
{
78+
self.dot(a)
10479
}
10580
}

tests/diag.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ use ndarray_linalg::*;
55
fn diag_1d() {
66
let d = arr1(&[1.0, 2.0]);
77
let v = arr1(&[1.0, 1.0]);
8-
let dv = d.into_diagonal().op(&v);
8+
let dv = d.into_diagonal().apply(&v);
99
assert_close_l2!(&dv, &arr1(&[1.0, 2.0]), 1e-7);
1010
}
1111

1212
#[test]
1313
fn diag_2d() {
1414
let d = arr1(&[1.0, 2.0]);
1515
let m = arr2(&[[1.0, 1.0], [1.0, 1.0]]);
16-
let dm = d.into_diagonal().op(&m);
16+
let dm = d.into_diagonal().apply2(&m);
1717
println!("dm = {:?}", dm);
1818
assert_close_l2!(&dm, &arr2(&[[1.0, 1.0], [2.0, 2.0]]), 1e-7);
1919
}
@@ -22,7 +22,7 @@ fn diag_2d() {
2222
fn diag_2d_multi() {
2323
let d = arr1(&[1.0, 2.0]);
2424
let m = arr2(&[[1.0, 1.0], [1.0, 1.0]]);
25-
let dm = d.into_diagonal().op_multi_into(m);
25+
let dm = d.into_diagonal().apply2_into(m);
2626
println!("dm = {:?}", dm);
2727
assert_close_l2!(&dm, &arr2(&[[1.0, 1.0], [2.0, 2.0]]), 1e-7);
2828
}

0 commit comments

Comments
 (0)