Skip to content

Commit 48c2106

Browse files
committed
Generate tests using paste crate
1 parent 2c787f5 commit 48c2106

File tree

2 files changed

+33
-41
lines changed

2 files changed

+33
-41
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,6 @@ version = "0.6"
4646
default-features = false
4747
features = ["static"]
4848
optional = true
49+
50+
[dev-dependencies]
51+
paste = "*"

tests/svd.rs

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,8 @@ use ndarray::*;
22
use ndarray_linalg::*;
33
use std::cmp::min;
44

5-
fn test(a: &Array2<f64>, n: usize, m: usize) {
6-
test_both(a, n, m);
7-
test_u(a, n, m);
8-
test_vt(a, n, m);
9-
}
10-
11-
fn test_both(a: &Array2<f64>, n: usize, m: usize) {
5+
fn test(a: &Array2<f64>) {
6+
let (n, m) = a.dim();
127
let answer = a.clone();
138
println!("a = \n{:?}", a);
149
let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap();
@@ -24,7 +19,8 @@ fn test_both(a: &Array2<f64>, n: usize, m: usize) {
2419
assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7);
2520
}
2621

27-
fn test_u(a: &Array2<f64>, n: usize, _m: usize) {
22+
fn test_u(a: &Array2<f64>) {
23+
let (n, _m) = a.dim();
2824
println!("a = \n{:?}", a);
2925
let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap();
3026
assert!(u.is_some());
@@ -34,7 +30,8 @@ fn test_u(a: &Array2<f64>, n: usize, _m: usize) {
3430
assert_eq!(u.dim().1, n);
3531
}
3632

37-
fn test_vt(a: &Array2<f64>, _n: usize, m: usize) {
33+
fn test_vt(a: &Array2<f64>) {
34+
let (_n, m) = a.dim();
3835
println!("a = \n{:?}", a);
3936
let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap();
4037
assert!(u.is_none());
@@ -44,38 +41,30 @@ fn test_vt(a: &Array2<f64>, _n: usize, m: usize) {
4441
assert_eq!(vt.dim().1, m);
4542
}
4643

47-
#[test]
48-
fn svd_square() {
49-
let a = random((3, 3));
50-
test(&a, 3, 3);
51-
}
52-
53-
#[test]
54-
fn svd_square_t() {
55-
let a = random((3, 3).f());
56-
test(&a, 3, 3);
57-
}
58-
59-
#[test]
60-
fn svd_3x4() {
61-
let a = random((3, 4));
62-
test(&a, 3, 4);
63-
}
44+
macro_rules! test_svd_impl {
45+
($test:ident, $n:expr, $m:expr) => {
46+
paste::item! {
47+
#[test]
48+
fn [<svd_ $test _ $n x $m>]() {
49+
let a = random(($n, $m));
50+
$test(&a);
51+
}
6452

65-
#[test]
66-
fn svd_3x4_t() {
67-
let a = random((3, 4).f());
68-
test(&a, 3, 4);
53+
#[test]
54+
fn [<svd_ $test _ $n x $m _t>]() {
55+
let a = random(($n, $m).f());
56+
$test(&a);
57+
}
58+
}
59+
};
6960
}
7061

71-
#[test]
72-
fn svd_4x3() {
73-
let a = random((4, 3));
74-
test(&a, 4, 3);
75-
}
76-
77-
#[test]
78-
fn svd_4x3_t() {
79-
let a = random((4, 3).f());
80-
test(&a, 4, 3);
81-
}
62+
test_svd_impl!(test, 3, 3);
63+
test_svd_impl!(test_u, 3, 3);
64+
test_svd_impl!(test_vt, 3, 3);
65+
test_svd_impl!(test, 4, 3);
66+
test_svd_impl!(test_u, 4, 3);
67+
test_svd_impl!(test_vt, 4, 3);
68+
test_svd_impl!(test, 3, 4);
69+
test_svd_impl!(test_u, 3, 4);
70+
test_svd_impl!(test_vt, 3, 4);

0 commit comments

Comments
 (0)