@@ -3,6 +3,12 @@ use ndarray_linalg::*;
3
3
use std:: cmp:: min;
4
4
5
5
fn test ( a : & Array2 < f64 > , n : usize , m : usize ) {
6
+ test_both ( a, n, m) ;
7
+ test_u ( a, n, m) ;
8
+ test_vt ( a, n, m) ;
9
+ }
10
+
11
+ fn test_both ( a : & Array2 < f64 > , n : usize , m : usize ) {
6
12
let answer = a. clone ( ) ;
7
13
println ! ( "a = \n {:?}" , a) ;
8
14
let ( u, s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , true ) . unwrap ( ) ;
@@ -18,6 +24,26 @@ fn test(a: &Array2<f64>, n: usize, m: usize) {
18
24
assert_close_l2 ! ( & u. dot( & sm) . dot( & vt) , & answer, 1e-7 ) ;
19
25
}
20
26
27
+ fn test_u ( a : & Array2 < f64 > , n : usize , _m : usize ) {
28
+ println ! ( "a = \n {:?}" , a) ;
29
+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , false ) . unwrap ( ) ;
30
+ assert ! ( u. is_some( ) ) ;
31
+ assert ! ( vt. is_none( ) ) ;
32
+ let u = u. unwrap ( ) ;
33
+ assert_eq ! ( u. dim( ) . 0 , n) ;
34
+ assert_eq ! ( u. dim( ) . 1 , n) ;
35
+ }
36
+
37
+ fn test_vt ( a : & Array2 < f64 > , _n : usize , m : usize ) {
38
+ println ! ( "a = \n {:?}" , a) ;
39
+ let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( false , true ) . unwrap ( ) ;
40
+ assert ! ( u. is_none( ) ) ;
41
+ assert ! ( vt. is_some( ) ) ;
42
+ let vt = vt. unwrap ( ) ;
43
+ assert_eq ! ( vt. dim( ) . 0 , m) ;
44
+ assert_eq ! ( vt. dim( ) . 1 , m) ;
45
+ }
46
+
21
47
#[ test]
22
48
fn svd_square ( ) {
23
49
let a = random ( ( 3 , 3 ) ) ;
0 commit comments