@@ -2,13 +2,8 @@ use ndarray::*;
2
2
use ndarray_linalg:: * ;
3
3
use std:: cmp:: min;
4
4
5
- fn test ( a : & Array2 < f64 > , n : usize , m : usize ) {
6
- test_both ( a, n, m) ;
7
- test_u ( a, n, m) ;
8
- test_vt ( a, n, m) ;
9
- }
10
-
11
- fn test_both ( a : & Array2 < f64 > , n : usize , m : usize ) {
5
+ fn test ( a : & Array2 < f64 > ) {
6
+ let ( n, m) = a. dim ( ) ;
12
7
let answer = a. clone ( ) ;
13
8
println ! ( "a = \n {:?}" , a) ;
14
9
let ( u, s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , true ) . unwrap ( ) ;
@@ -24,7 +19,8 @@ fn test_both(a: &Array2<f64>, n: usize, m: usize) {
24
19
assert_close_l2 ! ( & u. dot( & sm) . dot( & vt) , & answer, 1e-7 ) ;
25
20
}
26
21
27
- fn test_u ( a : & Array2 < f64 > , n : usize , _m : usize ) {
22
+ fn test_u ( a : & Array2 < f64 > ) {
23
+ let ( n, _m) = a. dim ( ) ;
28
24
println ! ( "a = \n {:?}" , a) ;
29
25
let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( true , false ) . unwrap ( ) ;
30
26
assert ! ( u. is_some( ) ) ;
@@ -34,7 +30,8 @@ fn test_u(a: &Array2<f64>, n: usize, _m: usize) {
34
30
assert_eq ! ( u. dim( ) . 1 , n) ;
35
31
}
36
32
37
- fn test_vt ( a : & Array2 < f64 > , _n : usize , m : usize ) {
33
+ fn test_vt ( a : & Array2 < f64 > ) {
34
+ let ( _n, m) = a. dim ( ) ;
38
35
println ! ( "a = \n {:?}" , a) ;
39
36
let ( u, _s, vt) : ( _ , Array1 < _ > , _ ) = a. svd ( false , true ) . unwrap ( ) ;
40
37
assert ! ( u. is_none( ) ) ;
@@ -44,38 +41,30 @@ fn test_vt(a: &Array2<f64>, _n: usize, m: usize) {
44
41
assert_eq ! ( vt. dim( ) . 1 , m) ;
45
42
}
46
43
47
- #[ test]
48
- fn svd_square ( ) {
49
- let a = random ( ( 3 , 3 ) ) ;
50
- test ( & a, 3 , 3 ) ;
51
- }
52
-
53
- #[ test]
54
- fn svd_square_t ( ) {
55
- let a = random ( ( 3 , 3 ) . f ( ) ) ;
56
- test ( & a, 3 , 3 ) ;
57
- }
58
-
59
- #[ test]
60
- fn svd_3x4 ( ) {
61
- let a = random ( ( 3 , 4 ) ) ;
62
- test ( & a, 3 , 4 ) ;
63
- }
44
+ macro_rules! test_svd_impl {
45
+ ( $test: ident, $n: expr, $m: expr) => {
46
+ paste:: item! {
47
+ #[ test]
48
+ fn [ <svd_ $test _ $n x $m>] ( ) {
49
+ let a = random( ( $n, $m) ) ;
50
+ $test( & a) ;
51
+ }
64
52
65
- #[ test]
66
- fn svd_3x4_t ( ) {
67
- let a = random ( ( 3 , 4 ) . f ( ) ) ;
68
- test ( & a, 3 , 4 ) ;
53
+ #[ test]
54
+ fn [ <svd_ $test _ $n x $m _t>] ( ) {
55
+ let a = random( ( $n, $m) . f( ) ) ;
56
+ $test( & a) ;
57
+ }
58
+ }
59
+ } ;
69
60
}
70
61
71
- #[ test]
72
- fn svd_4x3 ( ) {
73
- let a = random ( ( 4 , 3 ) ) ;
74
- test ( & a, 4 , 3 ) ;
75
- }
76
-
77
- #[ test]
78
- fn svd_4x3_t ( ) {
79
- let a = random ( ( 4 , 3 ) . f ( ) ) ;
80
- test ( & a, 4 , 3 ) ;
81
- }
62
+ test_svd_impl ! ( test, 3 , 3 ) ;
63
+ test_svd_impl ! ( test_u, 3 , 3 ) ;
64
+ test_svd_impl ! ( test_vt, 3 , 3 ) ;
65
+ test_svd_impl ! ( test, 4 , 3 ) ;
66
+ test_svd_impl ! ( test_u, 4 , 3 ) ;
67
+ test_svd_impl ! ( test_vt, 4 , 3 ) ;
68
+ test_svd_impl ! ( test, 3 , 4 ) ;
69
+ test_svd_impl ! ( test_u, 3 , 4 ) ;
70
+ test_svd_impl ! ( test_vt, 3 , 4 ) ;
0 commit comments