Skip to content

Commit 3e541e0

Browse files
Volodymyr OrlovVolodymyr Orlov
authored andcommitted
fix: improves SVD
1 parent a19398f commit 3e541e0

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/linalg/svd.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,19 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
5353
tol: T,
5454
}
5555

56+
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
57+
/// Diagonal matrix with singular values
58+
pub fn S(&self) -> M {
59+
let mut s = M::zeros(self.U.shape().1, self.V.shape().0);
60+
61+
for i in 0..self.s.len() {
62+
s.set(i, i, self.s[i]);
63+
}
64+
65+
s
66+
}
67+
}
68+
5669
/// Trait that implements SVD decomposition routine for any matrix.
5770
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
5871
/// Solves Ax = b. Overrides original matrix in the process.
@@ -711,4 +724,19 @@ mod tests {
711724
let w = a.svd_solve_mut(b).unwrap();
712725
assert!(w.approximate_eq(&expected_w, 1e-2));
713726
}
727+
728+
#[test]
729+
fn decompose_restore() {
730+
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
731+
let svd = a.svd().unwrap();
732+
let u: &DenseMatrix<f32> = &svd.U; //U
733+
let v: &DenseMatrix<f32> = &svd.V; // V
734+
let s: &DenseMatrix<f32> = &svd.S(); // Sigma
735+
736+
let a_hat = u.matmul(s).matmul(&v.transpose());
737+
738+
for (a, a_hat) in a.iter().zip(a_hat.iter()) {
739+
assert!((a - a_hat).abs() < 1e-3)
740+
}
741+
}
714742
}

0 commit comments

Comments
 (0)