Skip to content

Commit 1fe6b67

Browse files
authored
Merge pull request #312 from stan-dev/ppc-loo-psis_object
Allow `psis_object` argument for all ppc-loo plots
2 parents 09b813a + 4ec0743 commit 1fe6b67

File tree

3 files changed

+85
-53
lines changed

3 files changed

+85
-53
lines changed

R/ppc-loo.R

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
#' @param ... Currently unused.
1111
#' @param lw A matrix of (smoothed) log weights with the same dimensions as
1212
#' `yrep`. See [loo::psis()] and the associated `weights()` method as well as
13-
#' the **Examples** section, below.
13+
#' the **Examples** section, below. If `lw` is not specified then
14+
#' `psis_object` can be provided and log weights will be extracted.
15+
#' @param psis_object If using **loo** version `2.0.0` or greater, an
16+
#' object returned by the `psis()` function (or by the `loo()` function
17+
#' with argument `save_psis` set to `TRUE`).
1418
#' @param alpha,size,fatten,linewidth Arguments passed to code geoms to control plot
1519
#' aesthetics. For `ppc_loo_pit_qq()` and `ppc_loo_pit_overlay()`, `size` and
1620
#' `alpha` are passed to [ggplot2::geom_point()] and
@@ -71,7 +75,7 @@
7175
#' log_radon ~ floor + log_uranium + floor:log_uranium
7276
#' + (1 + floor | county),
7377
#' data = radon,
74-
#' iter = 1000,
78+
#' iter = 100,
7579
#' chains = 2,
7680
#' cores = 2
7781
#' )
@@ -89,6 +93,8 @@
8993
#' ppc_loo_pit_qq(y, yrep, lw = lw)
9094
#' ppc_loo_pit_qq(y, yrep, lw = lw, compare = "normal")
9195
#'
96+
#' # can use the psis object instead of lw
97+
#' ppc_loo_pit_qq(y, yrep, psis_object = psis1)
9298
#'
9399
#' # loo predictive intervals vs observations
94100
#' keep_obs <- 1:50
@@ -138,8 +144,9 @@ NULL
138144
#'
139145
ppc_loo_pit_overlay <- function(y,
140146
yrep,
141-
lw,
147+
lw = NULL,
142148
...,
149+
psis_object = NULL,
143150
pit = NULL,
144151
samples = 100,
145152
size = 0.25,
@@ -158,6 +165,7 @@ ppc_loo_pit_overlay <- function(y,
158165
y = y,
159166
yrep = yrep,
160167
lw = lw,
168+
psis_object = psis_object,
161169
pit = pit,
162170
samples = samples,
163171
bw = bw,
@@ -253,8 +261,9 @@ ppc_loo_pit_overlay <- function(y,
253261
ppc_loo_pit_data <-
254262
function(y,
255263
yrep,
256-
lw,
264+
lw = NULL,
257265
...,
266+
psis_object = NULL,
258267
pit = NULL,
259268
samples = 100,
260269
bw = "nrd0",
@@ -267,6 +276,7 @@ ppc_loo_pit_data <-
267276
suggested_package("rstantools")
268277
y <- validate_y(y)
269278
yrep <- validate_predictions(yrep, length(y))
279+
lw <- .get_lw(lw, psis_object)
270280
stopifnot(identical(dim(yrep), dim(lw)))
271281
pit <- rstantools::loo_pit(object = yrep, y = y, lw = lw)
272282
}
@@ -295,22 +305,24 @@ ppc_loo_pit_data <-
295305
#' @export
296306
ppc_loo_pit_qq <- function(y,
297307
yrep,
298-
lw,
299-
pit,
300-
compare = c("uniform", "normal"),
308+
lw = NULL,
301309
...,
310+
psis_object = NULL,
311+
pit = NULL,
312+
compare = c("uniform", "normal"),
302313
size = 2,
303314
alpha = 1) {
304315
check_ignored_arguments(...)
305316

306317
compare <- match.arg(compare)
307-
if (!missing(pit)) {
318+
if (!is.null(pit)) {
308319
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
309320
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
310321
} else {
311322
suggested_package("rstantools")
312323
y <- validate_y(y)
313324
yrep <- validate_predictions(yrep, length(y))
325+
lw <- .get_lw(lw, psis_object)
314326
stopifnot(identical(dim(yrep), dim(lw)))
315327
pit <- rstantools::loo_pit(object = yrep, y = y, lw = lw)
316328
}
@@ -352,7 +364,7 @@ ppc_loo_pit <-
352364
function(y,
353365
yrep,
354366
lw,
355-
pit,
367+
pit = NULL,
356368
compare = c("uniform", "normal"),
357369
...,
358370
size = 2,
@@ -374,18 +386,14 @@ ppc_loo_pit <-
374386
#' @rdname PPC-loo
375387
#' @export
376388
#' @template args-prob-prob_outer
377-
#' @param psis_object If using **loo** version `2.0.0` or greater, an
378-
#' object returned by the `psis()` function (or by the `loo()` function
379-
#' with argument `save_psis` set to `TRUE`).
380-
#' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`,
381-
#' optionally a matrix of precomputed LOO predictive intervals
382-
#' that can be specified instead of `yrep` and `lw` (these are both
383-
#' ignored if `intervals` is specified). If not specified the intervals
384-
#' are computed internally before plotting. If specified, `intervals`
385-
#' must be a matrix with number of rows equal to the number of data points and
386-
#' five columns in the following order: lower outer interval, lower inner
387-
#' interval, median (50%), upper inner interval and upper outer interval
388-
#' (column names are ignored).
389+
#' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`, optionally
390+
#' a matrix of pre-computed LOO predictive intervals that can be specified
391+
#' instead of `yrep` (ignored if `intervals` is specified). If not specified
392+
#' the intervals are computed internally before plotting. If specified,
393+
#' `intervals` must be a matrix with number of rows equal to the number of
394+
#' data points and five columns in the following order: lower outer interval,
395+
#' lower inner interval, median (50%), upper inner interval and upper outer
396+
#' interval (column names are ignored).
389397
#' @param order For `ppc_loo_intervals()`, a string indicating how to arrange
390398
#' the plotted intervals. The default (`"index"`) is to plot them in the
391399
#' order of the observations. The alternative (`"median"`) arranges them
@@ -403,9 +411,9 @@ ppc_loo_intervals <-
403411
function(y,
404412
yrep,
405413
psis_object,
414+
...,
406415
subset = NULL,
407416
intervals = NULL,
408-
...,
409417
prob = 0.5,
410418
prob_outer = 0.9,
411419
alpha = 0.33,
@@ -498,11 +506,10 @@ ppc_loo_intervals <-
498506
ppc_loo_ribbon <-
499507
function(y,
500508
yrep,
501-
lw,
502509
psis_object,
510+
...,
503511
subset = NULL,
504512
intervals = NULL,
505-
...,
506513
prob = 0.5,
507514
prob_outer = 0.9,
508515
alpha = 0.33,
@@ -720,3 +727,17 @@ ppc_loo_ribbon <-
720727

721728
list(xs = xs, unifs = bc_mat)
722729
}
730+
731+
# Extract log weights from psis_object if provided
732+
.get_lw <- function(lw = NULL, psis_object = NULL) {
733+
if (is.null(lw) && is.null(psis_object)) {
734+
abort("One of 'lw' and 'psis_object' must be specified.")
735+
} else if (is.null(lw)) {
736+
suggested_package("loo", min_version = "2.0.0")
737+
if (!loo::is.psis(psis_object)) {
738+
abort("If specified, 'psis_object' must be a PSIS object from the loo package.")
739+
}
740+
lw <- loo::weights.importance_sampling(psis_object)
741+
}
742+
lw
743+
}

man/PPC-loo.Rd

Lines changed: 28 additions & 24 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-ppc-loo.R

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ test_that("ppc_loo_pit_overlay returns ggplot object", {
3535
skip_if_not_installed("rstanarm")
3636
skip_if_not_installed("loo")
3737
expect_gg(ppc_loo_pit_overlay(y, yrep, lw, samples = 25))
38+
expect_gg(ppc_loo_pit_overlay(y, yrep, psis_object = psis1, samples = 25))
3839
})
3940

4041
test_that("ppc_loo_pit_overlay warns about binary data", {
@@ -65,29 +66,35 @@ test_that("ppc_loo_pit_qq returns ggplot object", {
6566
skip_if_not_installed("rstanarm")
6667
skip_if_not_installed("loo")
6768
expect_gg(p1 <- ppc_loo_pit_qq(y, yrep, lw))
69+
expect_gg(p2 <- ppc_loo_pit_qq(y, yrep, psis_object = psis1))
6870
expect_equal(p1$labels$x, "Uniform")
69-
expect_gg(p2 <- ppc_loo_pit_qq(y, yrep, lw, compare = "normal"))
70-
expect_equal(p2$labels$x, "Normal")
71+
expect_equal(p1$data, p2$data)
72+
expect_gg(p3 <- ppc_loo_pit_qq(y, yrep, lw, compare = "normal"))
73+
expect_equal(p3$labels$x, "Normal")
7174
})
7275

7376
test_that("ppc_loo_pit functions work when pit specified instead of y,yrep,lw", {
7477
skip_if_not_installed("rstanarm")
7578
skip_if_not_installed("loo")
7679
expect_gg(ppc_loo_pit_qq(pit = pits))
7780
expect_message(
78-
ppc_loo_pit_qq(y = y, yrep = yrep, lw = lw, pit = pits),
81+
p1 <- ppc_loo_pit_qq(y = y, yrep = yrep, lw = lw, pit = pits),
7982
"'pit' specified so ignoring 'y','yrep','lw' if specified"
8083
)
84+
expect_message(
85+
p2 <- ppc_loo_pit_qq(pit = pits)
86+
)
87+
expect_equal(p1$data, p2$data)
8188

82-
expect_gg(ppc_loo_pit_overlay(pit = pits))
89+
90+
expect_gg(p1 <- ppc_loo_pit_overlay(pit = pits))
8391
expect_message(
8492
ppc_loo_pit_overlay(y = y, yrep = yrep, lw = lw, pit = pits),
8593
"'pit' specified so ignoring 'y','yrep','lw' if specified"
8694
)
8795
})
8896

8997

90-
9198
test_that("ppc_loo_intervals returns ggplot object", {
9299
skip_if_not_installed("rstanarm")
93100
skip_if_not_installed("loo")

0 commit comments

Comments
 (0)