Skip to content

Commit 38b31bd

Browse files
committed
improve test coverage
1 parent 863238a commit 38b31bd

11 files changed

+67
-44
lines changed

R/helpers-testthat.R

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,15 @@ expect_gg <- function(x) {
44
}
55
expect_gtable <- function(x) testthat::expect_s3_class(x, "gtable")
66
expect_mcmc_array <- function(x) testthat::expect_true(is_mcmc_array(x))
7-
expect_bayesplot_grid <- function(x) testthat::expect_s3_class(x, "bayesplot_grid")
7+
expect_bayesplot_grid <- function(x) testthat::expect_true(is_bayesplot_grid(x))
8+
9+
10+
# Insert fake divergences for testing purposes
11+
#
12+
# @param np Data frame returned by nuts_params
13+
# @return 'np' with every other iter marked as a divergence
14+
ensure_divergences <- function(np) {
15+
divs <- rep_len(c(0,1), length.out = sum(np$Parameter=="divergent__"))
16+
np$Value[np$Parameter=="divergent__"] <- divs
17+
return(np)
18+
}

R/mcmc-traces.R

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,21 +158,18 @@ mcmc_trace <-
158158
divergences = NULL) {
159159

160160
# deprecate 'divergences' arg in favor of 'np' (for consistency across functions)
161-
if (!is.null(divergences)) {
161+
if (!is.null(np) && !is.null(divergences)) {
162+
stop(
163+
"'np' and 'divergences' can't both be specified. ",
164+
"Use only 'np' (the 'divergences' argument is deprecated)."
165+
)
166+
} else if (!is.null(divergences)) {
162167
warning(
163168
"The 'divergences' argument is deprecated ",
164169
"and will be removed in a future release. ",
165170
"Use the 'np' argument instead."
166171
)
167-
168-
if (is.null(np)) {
169-
np <- divergences
170-
} else {
171-
stop(
172-
"'np' and 'divergences' can't both be specified. ",
173-
"Use only 'np' (the 'divergences' argument is deprecated)."
174-
)
175-
}
172+
np <- divergences
176173
}
177174

178175
check_ignored_arguments(...)

tests/testthat/test-aesthetics.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ test_that("color_scheme_set throws correct errors for custom schemes ", {
7575
"not found: not_a_color1, not_a_color2")
7676
expect_error(color_scheme_set(c("red", "blue")),
7777
"should be a character vector of length 1 or 6")
78+
expect_error(prepare_custom_colors(c("red", "blue")),
79+
"Custom color schemes must contain exactly 6 colors")
7880
})
7981

8082
test_that("mixed_scheme internal function doesn't error", {
@@ -180,3 +182,4 @@ test_that("ggplot2::theme_set overrides bayesplot theme", {
180182
})
181183

182184
bayesplot_theme_set(bayesplot::theme_default())
185+
color_scheme_set()

tests/testthat/test-extractors.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ context("Extractors")
55
ITER <- 1000
66
CHAINS <- 3
77
fit <- stan_glm(mpg ~ wt + am, data = mtcars,
8-
iter = ITER, chains = CHAINS, refresh = 0)
8+
iter = ITER, chains = CHAINS,
9+
refresh = 0)
910

1011
x <- list(cbind(a = 1:3, b = rnorm(3)), cbind(a = 1:3, b = rnorm(3)))
1112

@@ -68,6 +69,8 @@ test_that("rhat.stanreg returns correct structure", {
6869
})
6970

7071
test_that("neff_ratio.stanreg returns correct structure", {
72+
expect_named(neff_ratio(fit, pars = c("wt", "am")), c("wt", "am"))
73+
7174
ratio <- neff_ratio(fit)
7275
expect_named(ratio)
7376
ans <- summary(fit)[1:length(ratio), "n_eff"] / (floor(ITER / 2) * CHAINS)

tests/testthat/test-helpers-mcmc.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ test_that("validate_chain_list works", {
138138
colnames(chainlist2[[1]])[1] <- "AAA"
139139
expect_error(validate_chain_list(chainlist2), "parameters for each chain")
140140

141+
chainlist3 <- chainlist
142+
colnames(chainlist3[[1]]) <- c("", colnames(chainlist[[1]])[-1])
143+
expect_error(validate_chain_list(chainlist3), "Some parameters are missing names")
144+
141145
chainlist[[1]] <- chainlist[[1]][-1, ]
142146
expect_error(validate_chain_list(chainlist),
143147
"Each chain should have the same number of iterations")
@@ -205,6 +209,10 @@ test_that("transformations recycled properly if not a named list", {
205209

206210

207211
# prepare_mcmc_array ------------------------------------------------------
212+
test_that("prepare_mcmc_array errors if NAs", {
213+
arr[1,1,1] <- NA
214+
expect_error(prepare_mcmc_array(arr), "NAs not allowed")
215+
})
208216
test_that("prepare_mcmc_array processes non-array input types correctly", {
209217
# errors are mostly covered by tests of the many internal functions above
210218

tests/testthat/test-mcmc-nuts.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@ context("MCMC: nuts")
44

55
ITER <- 1000
66
CHAINS <- 3
7-
capture.output(
8-
fit <- stan_glm(mpg ~ wt + am, data = mtcars,
9-
iter = ITER, chains = CHAINS, refresh = 0)
10-
)
7+
fit <- stan_glm(mpg ~ wt + am, data = mtcars,
8+
iter = ITER, chains = CHAINS,
9+
refresh = 0)
1110
np <- nuts_params(fit)
1211
lp <- log_posterior(fit)
1312

@@ -21,6 +20,7 @@ test_that("all mcmc_nuts_* (except energy) return gtable objects", {
2120
expect_gtable(mcmc_nuts_stepsize(np, lp))
2221
expect_gtable(mcmc_nuts_stepsize(np, lp, chain = CHAINS))
2322

23+
np <- ensure_divergences(np)
2424
expect_gtable(mcmc_nuts_divergence(np, lp))
2525
expect_gtable(mcmc_nuts_divergence(np, lp, chain = CHAINS))
2626
})

tests/testthat/test-mcmc-recover.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ context("MCMC: recover")
55
alpha <- 1; beta <- c(-.5, .5); sigma <- 2
66
X <- matrix(rnorm(200), 100, 2)
77
y <- rnorm(100, mean = c(alpha + X %*% beta), sd = sigma)
8-
capture.output(
9-
fit <- stan_glm(y ~ ., data = data.frame(y, X))
10-
)
8+
fit <- stan_glm(y ~ ., data = data.frame(y, X), refresh = 0, iter = 500, chains = 2)
119
draws <- as.matrix(fit)
1210
true <- c(alpha, beta, sigma)
1311

tests/testthat/test-mcmc-scatter-and-parcoord.R

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,10 @@ context("MCMC: scatter and parallel coordinates plots")
55
source(test_path("data-for-mcmc-tests.R"))
66

77
# also fit an rstanarm model to use with mcmc_pairs
8-
capture.output(
9-
fit <- stan_glm(mpg ~ wt + am, data = mtcars, iter = 1000, chains = 2, refresh = 0)
10-
)
8+
fit <- stan_glm(mpg ~ wt + am, data = mtcars, iter = 1000, chains = 2, refresh = 0)
119
post <- as.array(fit)
1210
lp <- log_posterior(fit)
13-
np <- nuts_params(fit)
14-
divs <- sample(c(0,1), size = 1000, prob = c(0.25, 0.75), replace = TRUE)
15-
np$Value[np$Parameter=="divergent__"] <- divs # fake divergences
16-
11+
np <- ensure_divergences(nuts_params(fit))
1712

1813
test_that("mcmc_scatter returns a ggplot object", {
1914
expect_gg(mcmc_scatter(arr, pars = c("beta[1]", "beta[2]")))
@@ -50,7 +45,9 @@ test_that("mcmc_scatter accepts NUTS info", {
5045

5146
# mcmc_pairs -------------------------------------------------------------
5247
test_that("mcmc_pairs returns a bayesplot_grid object", {
53-
expect_bayesplot_grid(mcmc_pairs(arr, pars = c("(Intercept)", "sigma")))
48+
g <- mcmc_pairs(arr, pars = c("(Intercept)", "sigma"))
49+
expect_bayesplot_grid(g)
50+
expect_equal(print(g), plot(g))
5451
expect_bayesplot_grid(mcmc_pairs(arr, pars = "sigma", regex_pars = "beta"))
5552
expect_bayesplot_grid(mcmc_pairs(arr, regex_pars = "x:[1-3]",
5653
transformations = "exp",

tests/testthat/test-mcmc-traces.R

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ test_that("mcmc_trace_highlight throws error if 1 chain but multiple chains requ
2929
expect_error(mcmc_trace_highlight(arr1chain, highlight = 1), "requires multiple chains")
3030
})
3131

32+
test_that("mcmc_trace_highlight throws error if highlight > number of chains", {
33+
expect_error(mcmc_trace_highlight(arr, pars = "sigma", highlight = 7), "'highlight' is 7")
34+
})
35+
3236
# options -----------------------------------------------------------------
3337
test_that("mcmc_trace options work", {
3438
expect_gg(g1 <- mcmc_trace(arr, regex_pars = "beta", window = c(5, 10)))
@@ -46,41 +50,44 @@ test_that("mcmc_trace options work", {
4650

4751
# displaying divergences in traceplot -------------------------------------
4852
test_that("mcmc_trace 'np' argument works", {
53+
skip_if_not_installed("rstanarm")
4954
suppressPackageStartupMessages(library(rstanarm))
50-
suppressWarnings(capture.output(
51-
fit <- stan_glm(mpg ~ ., data = mtcars, iter = 200, refresh = 0,
52-
prior = hs(), adapt_delta = 0.7)
53-
))
55+
fit <- stan_glm(mpg ~ wt + am, data = mtcars, iter = 1000, chains = 2, refresh = 0)
5456
draws <- as.array(fit)
5557

5658
# divergences via nuts_params
57-
divs <- nuts_params(fit, pars = "divergent__")
58-
g <- mcmc_trace(draws, pars = "sigma", np = divs)
59+
divs1 <- ensure_divergences(nuts_params(fit, pars = "divergent__"))
60+
g <- mcmc_trace(draws, pars = "sigma", np = divs1)
5961
expect_gg(g)
6062
l2_data <- g$layers[[2]]$data
6163
expect_equal(names(l2_data), "Divergent")
6264

6365
# divergences as vector via 'divergences' arg should throw deprecation warning
64-
divs2 <- sample(c(0,1), nrow(draws), replace = TRUE)
66+
divs2 <- rep_len(c(0,1), length.out = nrow(draws))
6567
expect_warning(
6668
g2 <- mcmc_trace(draws, pars = "sigma", divergences = divs2),
6769
regexp = "deprecated"
6870
)
6971
expect_gg(g2)
7072

73+
expect_error(
74+
mcmc_trace(draws, pars = "sigma", np = divs1, divergences = divs2),
75+
"can't both be specified"
76+
)
77+
7178
# check errors & messages
7279
expect_error(mcmc_trace(draws, pars = "sigma", np = 1),
7380
"length(divergences) == n_iter is not TRUE",
7481
fixed = TRUE)
75-
expect_error(mcmc_trace(draws[,1:2,], pars = "sigma", np = divs),
82+
expect_error(mcmc_trace(draws[,1,], pars = "sigma", np = divs1),
7683
"num_chains(np) == n_chain is not TRUE",
7784
fixed = TRUE)
78-
expect_error(mcmc_trace(draws, pars = "sigma", np = divs[1:10, ]),
85+
expect_error(mcmc_trace(draws, pars = "sigma", np = divs1[1:10, ]),
7986
"num_iters(np) == n_iter is not TRUE",
8087
fixed = TRUE)
8188

82-
divs$Value[divs$Parameter == "divergent__"] <- 0
83-
expect_message(mcmc_trace(draws, pars = "sigma", np = divs),
89+
divs1$Value[divs1$Parameter == "divergent__"] <- 0
90+
expect_message(mcmc_trace(draws, pars = "sigma", np = divs1),
8491
"No divergences to plot.")
8592
})
8693

tests/testthat/test-ppc-intervals.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ test_that("ppc_intervals_grouped returns ggplot object", {
2727
})
2828

2929
test_that("ppc_ribbon_grouped returns ggplot object", {
30-
expect_gg(
31-
ppc_ribbon_grouped(y, yrep, x, group, facet_args = list(scales = "fixed")))
30+
expect_gg(ppc_ribbon_grouped(y, yrep, x, group))
31+
expect_gg(ppc_ribbon_grouped(y, yrep, x, group, facet_args = list(scales = "fixed")))
3232
})
3333

3434
test_that("ppc_intervals_data returns correct structure", {

0 commit comments

Comments
 (0)