Skip to content

Commit ae573cd

Browse files
author
Ozan Adiguzel
committed
handle posterior's draws objects
1 parent 5555083 commit ae573cd

File tree

8 files changed

+33
-4
lines changed

8 files changed

+33
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
inst/doc
66

77
tests/testthat/Rplots.pdf
8+
tests/testthat/_snaps/
89

910
.DS_Store
1011

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Imports:
3131
ggplot2 (>= 3.0.0),
3232
ggridges,
3333
glue,
34+
posterior,
3435
reshape2,
3536
rlang (>= 0.3.0),
3637
stats,
@@ -53,7 +54,7 @@ Suggests:
5354
survival,
5455
testthat (>= 2.0.0),
5556
vdiffr
56-
RoxygenNote: 7.1.1
57+
RoxygenNote: 7.1.2
5758
VignetteBuilder: knitr
5859
Encoding: UTF-8
5960
Roxygen: list(markdown = TRUE)

R/helpers-mcmc.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ prepare_mcmc_array <- function(x,
99
pars = character(),
1010
regex_pars = character(),
1111
transformations = list()) {
12-
if (is_df_with_chain(x)) {
12+
if (posterior::is_draws(x)) {
13+
x <- posterior::as_draws_array(x)
14+
} else if (is_df_with_chain(x)) {
1315
x <- df_with_chain2array(x)
1416
} else if (is_chain_list(x)) {
1517
# this will apply to mcmc.list and similar objects

man/MCMC-distributions.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/data-for-mcmc-tests.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ set.seed(8420)
33
# Prepare input objects
44
arr <- array(rnorm(4000), dim = c(100, 4, 10))
55
arr1chain <- arr[, 1, , drop = FALSE]
6+
drawsarr <- posterior::example_draws()
7+
drawsarr1chain <- drawsarr[, 1, , drop = FALSE]
68
mat <- matrix(rnorm(1000), nrow = 100, ncol = 10)
79
dframe <- as.data.frame(mat)
810
chainlist <- list(matrix(rnorm(1000), nrow = 100, ncol = 10),
@@ -16,6 +18,7 @@ chainlist1chain <- chainlist[1]
1618

1719
# one parameter
1820
arr1 <- arr[, , 1, drop = FALSE]
21+
drawsarr1 <- drawsarr[, , 1, drop = FALSE]
1922
mat1 <- mat[, 1, drop = FALSE]
2023
dframe1 <- dframe[, 1, drop = FALSE]
2124
chainlist1 <- list(chainlist[[1]][, 1, drop=FALSE],

tests/testthat/test-mcmc-combo.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ test_that("mcmc_combo returns a gtable object", {
77
expect_gtable(mcmc_combo(arr, regex_pars = "beta"))
88
expect_gtable(mcmc_combo(arr, regex_pars = "beta",
99
gg_theme = ggplot2::theme_dark()))
10+
expect_gtable(mcmc_combo(drawsarr, regex_pars = "theta"))
1011
expect_gtable(mcmc_combo(mat, regex_pars = "beta",
1112
binwidth = 1/20, combo = c("dens", "hist"),
1213
facet_args = list(nrow = 2)))
@@ -18,6 +19,7 @@ test_that("mcmc_combo returns a gtable object", {
1819
combo = c("trace", "hist")))
1920

2021
expect_gtable(mcmc_combo(arr1, pars = "(Intercept)"))
22+
expect_gtable(mcmc_combo(drawsarr1))
2123
expect_gtable(mcmc_combo(mat1))
2224
expect_gtable(mcmc_combo(dframe1))
2325
})
@@ -27,6 +29,9 @@ test_that("mcmc_combo throws error if 1 chain but multiple chains required", {
2729
expect_error(mcmc_combo(arr1chain, regex_pars = "beta",
2830
combo = c("trace_highlight", "dens")),
2931
"requires multiple chains")
32+
expect_error(mcmc_combo(drawsarr1chain, regex_pars = "theta",
33+
combo = c("trace_highlight", "dens")),
34+
"requires multiple chains")
3035
expect_error(mcmc_combo(mat, regex_pars = "beta",
3136
combo = c("trace_highlight", "hist")),
3237
"requires multiple chains")

tests/testthat/test-mcmc-distributions.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,23 @@ get_palette <- function(ggplot, n) {
1111
test_that("mcmc_hist returns a ggplot object", {
1212
expect_gg(mcmc_hist(arr, pars = "beta[1]", regex_pars = "x\\:"))
1313
expect_gg(mcmc_hist(arr1chain, regex_pars = "beta"))
14+
expect_gg(mcmc_hist(drawsarr, pars = "theta[1]"))
15+
expect_gg(mcmc_hist(drawsarr1chain, regex_pars = "theta"))
1416
expect_gg(mcmc_hist(mat))
1517
expect_gg(mcmc_hist(dframe))
1618
expect_gg(mcmc_hist(dframe_multiple_chains))
1719

1820
expect_gg(mcmc_hist(arr1))
21+
expect_gg(mcmc_hist(drawsarr1))
1922
expect_gg(mcmc_hist(mat1))
2023
expect_gg(mcmc_hist(dframe1))
2124
})
2225

2326
test_that("mcmc_dens returns a ggplot object", {
2427
expect_gg(mcmc_dens(arr, pars = "beta[2]", regex_pars = "x\\:"))
2528
expect_gg(mcmc_dens(arr1chain, regex_pars = "beta"))
29+
expect_gg(mcmc_hist(drawsarr, pars = "theta[1]"))
30+
expect_gg(mcmc_hist(drawsarr1chain, regex_pars = "theta"))
2631
expect_gg(mcmc_dens(mat))
2732

2833
expect_gg(mcmc_dens(dframe, transformations = list(sigma = function(x) x^2)))
@@ -33,6 +38,7 @@ test_that("mcmc_dens returns a ggplot object", {
3338
))
3439

3540
expect_gg(mcmc_dens(arr1))
41+
expect_gg(mcmc_hist(drawsarr1))
3642
expect_gg(mcmc_dens(mat1))
3743
expect_gg(mcmc_dens(dframe1))
3844
})
@@ -92,18 +98,22 @@ test_that("mcmc_* throws error if 1 chain but multiple chains required", {
9298
expect_error(mcmc_hist_by_chain(mat), "requires multiple chains")
9399
expect_error(mcmc_hist_by_chain(dframe), "requires multiple chains")
94100
expect_error(mcmc_hist_by_chain(arr1chain), "requires multiple chains")
101+
expect_error(mcmc_hist_by_chain(drawsarr1chain), "requires multiple chains")
95102

96103
expect_error(mcmc_dens_overlay(mat), "requires multiple chains")
97104
expect_error(mcmc_dens_overlay(dframe), "requires multiple chains")
98105
expect_error(mcmc_dens_overlay(arr1chain), "requires multiple chains")
106+
expect_error(mcmc_dens_overlay(drawsarr1chain), "requires multiple chains")
99107

100108
expect_error(mcmc_dens_chains(mat), "requires multiple chains")
101109
expect_error(mcmc_dens_chains(dframe), "requires multiple chains")
102110
expect_error(mcmc_dens_chains(arr1chain), "requires multiple chains")
111+
expect_error(mcmc_dens_chains(drawsarr1chain), "requires multiple chains")
103112

104113
expect_error(mcmc_violin(mat), "requires multiple chains")
105114
expect_error(mcmc_violin(dframe), "requires multiple chains")
106115
expect_error(mcmc_violin(arr1chain), "requires multiple chains")
116+
expect_error(mcmc_violin(drawsarr1chain), "requires multiple chains")
107117
})
108118

109119

tests/testthat/test-mcmc-scatter-and-parcoord.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ if (requireNamespace("rstanarm", quietly = TRUE)) {
1717
test_that("mcmc_scatter returns a ggplot object", {
1818
expect_gg(mcmc_scatter(arr, pars = c("beta[1]", "beta[2]")))
1919
expect_gg(mcmc_scatter(arr1chain, regex_pars = "beta", size = 3, alpha = 0.5))
20+
expect_gg(mcmc_scatter(drawsarr, pars = c("theta[1]", "theta[2]")))
2021
expect_gg(mcmc_scatter(mat, pars = c("sigma", "(Intercept)")))
2122
expect_gg(mcmc_scatter(dframe, regex_pars = "x:[2,4]"))
2223
expect_gg(mcmc_scatter(dframe_multiple_chains,
@@ -26,7 +27,9 @@ test_that("mcmc_scatter returns a ggplot object", {
2627
test_that("mcmc_scatter throws error if number of parameters is not 2", {
2728
expect_error(mcmc_scatter(arr, pars = c("sigma", "beta[1]", "beta[2]")), "exactly 2 parameters")
2829
expect_error(mcmc_scatter(arr, pars = "sigma"), "exactly 2 parameters")
30+
expect_error(mcmc_scatter(drawsarr, pars = "mu"), "exactly 2 parameters")
2931
expect_error(mcmc_scatter(arr1), "exactly 2 parameters")
32+
expect_error(mcmc_scatter(drawsarr1), "exactly 2 parameters")
3033
expect_error(mcmc_scatter(mat1), "exactly 2 parameters")
3134
})
3235

@@ -46,12 +49,14 @@ test_that("mcmc_hex returns a ggplot object", {
4649
skip_if_not_installed("hexbin")
4750
expect_gg(mcmc_hex(arr, pars = c("beta[1]", "beta[2]")))
4851
expect_gg(mcmc_hex(arr1chain, regex_pars = "beta", binwidth = c(.5,.5)))
52+
expect_gg(mcmc_hex(drawsarr, pars = c("theta[1]", "theta[2]")))
4953
})
5054

5155
test_that("mcmc_hex throws error if number of parameters is not 2", {
5256
skip_if_not_installed("hexbin")
5357
expect_error(mcmc_hex(arr, pars = c("sigma", "beta[1]", "beta[2]")), "exactly 2 parameters")
5458
expect_error(mcmc_hex(arr, pars = "sigma"), "exactly 2 parameters")
59+
expect_error(mcmc_hex(drawsarr, pars = "mu"), "exactly 2 parameters")
5560
expect_error(mcmc_hex(arr1), "exactly 2 parameters")
5661
expect_error(mcmc_hex(mat1), "exactly 2 parameters")
5762
})
@@ -69,8 +74,10 @@ test_that("mcmc_pairs returns a bayesplot_grid object", {
6974
diag_fun = "dens", off_diag_fun = "hex",
7075
diag_args = list(trim = FALSE),
7176
off_diag_args = list(binwidth = c(0.5, 0.5))))
77+
expect_bayesplot_grid(mcmc_pairs(drawsarr, pars = "mu", regex_pars = "theta"))
7278

7379
expect_bayesplot_grid(suppressWarnings(mcmc_pairs(arr1chain, regex_pars = "beta")))
80+
expect_bayesplot_grid(suppressWarnings(mcmc_pairs(drawsarr1chain, regex_pars = "theta")))
7481
expect_bayesplot_grid(suppressWarnings(mcmc_pairs(mat, pars = c("(Intercept)", "sigma"))))
7582
expect_bayesplot_grid(suppressWarnings(mcmc_pairs(dframe, pars = c("(Intercept)", "sigma"))))
7683
expect_bayesplot_grid(mcmc_pairs(dframe_multiple_chains, regex_pars = "beta"))

0 commit comments

Comments
 (0)