Skip to content

Commit b461e81

Browse files
Simplify and expand test ranges in test_normal_horseshoe_sampler
1 parent fe5190e commit b461e81

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tests/test_gibbs.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ def test_horseshoe_match(srng):
8686

8787

8888
@pytest.mark.parametrize(
89-
"N, p, nonzero_atol",
89+
"N, p, rtol",
9090
[
91-
(50, 10, np.array([1.0, 0.5, 0.5, 3e-1, 3e-1])),
92-
(50, 55, np.array([1.5, 0.5, 0.5, 0.75, 3e-1])),
91+
(50, 10, 0.5),
92+
(50, 75, 0.5),
9393
],
9494
)
95-
def test_normal_horseshoe_sampler(srng, N, p, nonzero_atol):
95+
def test_normal_horseshoe_sampler(srng, N, p, rtol):
9696
"""Check the results of a normal regression model with a Horseshoe prior.
9797
9898
This test example is modified from section 3.2 of Makalic & Schmidt (2016)
@@ -131,8 +131,7 @@ def test_normal_horseshoe_sampler(srng, N, p, nonzero_atol):
131131
assert np.all(lambda_post_val >= 0)
132132

133133
beta_post_median = np.median(beta_post_vals[100::2], axis=0)
134-
assert np.allclose(beta_post_median[:5], true_beta[:5], atol=nonzero_atol)
135-
assert np.all(np.abs(beta_post_median[5:]) < 1)
134+
assert np.allclose(beta_post_median[:5], true_beta[:5], atol=1e-1, rtol=rtol)
136135

137136

138137
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)