@@ -53,6 +53,19 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
53
53
tol : T ,
54
54
}
55
55
56
+ impl < T : RealNumber , M : SVDDecomposableMatrix < T > > SVD < T , M > {
57
+ /// Diagonal matrix with singular values
58
+ pub fn S ( & self ) -> M {
59
+ let mut s = M :: zeros ( self . U . shape ( ) . 1 , self . V . shape ( ) . 0 ) ;
60
+
61
+ for i in 0 ..self . s . len ( ) {
62
+ s. set ( i, i, self . s [ i] ) ;
63
+ }
64
+
65
+ s
66
+ }
67
+ }
68
+
56
69
/// Trait that implements SVD decomposition routine for any matrix.
57
70
pub trait SVDDecomposableMatrix < T : RealNumber > : BaseMatrix < T > {
58
71
/// Solves Ax = b. Overrides original matrix in the process.
@@ -711,4 +724,19 @@ mod tests {
711
724
let w = a. svd_solve_mut ( b) . unwrap ( ) ;
712
725
assert ! ( w. approximate_eq( & expected_w, 1e-2 ) ) ;
713
726
}
727
+
728
+ #[ test]
729
+ fn decompose_restore ( ) {
730
+ let a = DenseMatrix :: from_2d_array ( & [ & [ 1.0 , 2.0 , 3.0 , 4.0 ] , & [ 5.0 , 6.0 , 7.0 , 8.0 ] ] ) ;
731
+ let svd = a. svd ( ) . unwrap ( ) ;
732
+ let u: & DenseMatrix < f32 > = & svd. U ; //U
733
+ let v: & DenseMatrix < f32 > = & svd. V ; // V
734
+ let s: & DenseMatrix < f32 > = & svd. S ( ) ; // Sigma
735
+
736
+ let a_hat = u. matmul ( s) . matmul ( & v. transpose ( ) ) ;
737
+
738
+ for ( a, a_hat) in a. iter ( ) . zip ( a_hat. iter ( ) ) {
739
+ assert ! ( ( a - a_hat) . abs( ) < 1e-3 )
740
+ }
741
+ }
714
742
}
0 commit comments