10
10
# ' @param ... Currently unused.
11
11
# ' @param lw A matrix of (smoothed) log weights with the same dimensions as
12
12
# ' `yrep`. See [loo::psis()] and the associated `weights()` method as well as
13
- # ' the **Examples** section, below.
13
+ # ' the **Examples** section, below. If `lw` is not specified then
14
+ # ' `psis_object` can be provided and log weights will be extracted.
15
+ # ' @param psis_object If using **loo** version `2.0.0` or greater, an
16
+ # ' object returned by the `psis()` function (or by the `loo()` function
17
+ # ' with argument `save_psis` set to `TRUE`).
14
18
# ' @param alpha,size,fatten,linewidth Arguments passed to code geoms to control plot
15
19
# ' aesthetics. For `ppc_loo_pit_qq()` and `ppc_loo_pit_overlay()`, `size` and
16
20
# ' `alpha` are passed to [ggplot2::geom_point()] and
71
75
# ' log_radon ~ floor + log_uranium + floor:log_uranium
72
76
# ' + (1 + floor | county),
73
77
# ' data = radon,
74
- # ' iter = 1000 ,
78
+ # ' iter = 100 ,
75
79
# ' chains = 2,
76
80
# ' cores = 2
77
81
# ' )
89
93
# ' ppc_loo_pit_qq(y, yrep, lw = lw)
90
94
# ' ppc_loo_pit_qq(y, yrep, lw = lw, compare = "normal")
91
95
# '
96
+ # ' # can use the psis object instead of lw
97
+ # ' ppc_loo_pit_qq(y, yrep, psis_object = psis1)
92
98
# '
93
99
# ' # loo predictive intervals vs observations
94
100
# ' keep_obs <- 1:50
138
144
# '
139
145
ppc_loo_pit_overlay <- function (y ,
140
146
yrep ,
141
- lw ,
147
+ lw = NULL ,
142
148
... ,
149
+ psis_object = NULL ,
143
150
pit = NULL ,
144
151
samples = 100 ,
145
152
size = 0.25 ,
@@ -158,6 +165,7 @@ ppc_loo_pit_overlay <- function(y,
158
165
y = y ,
159
166
yrep = yrep ,
160
167
lw = lw ,
168
+ psis_object = psis_object ,
161
169
pit = pit ,
162
170
samples = samples ,
163
171
bw = bw ,
@@ -253,8 +261,9 @@ ppc_loo_pit_overlay <- function(y,
253
261
ppc_loo_pit_data <-
254
262
function (y ,
255
263
yrep ,
256
- lw ,
264
+ lw = NULL ,
257
265
... ,
266
+ psis_object = NULL ,
258
267
pit = NULL ,
259
268
samples = 100 ,
260
269
bw = " nrd0" ,
@@ -267,6 +276,7 @@ ppc_loo_pit_data <-
267
276
suggested_package(" rstantools" )
268
277
y <- validate_y(y )
269
278
yrep <- validate_predictions(yrep , length(y ))
279
+ lw <- .get_lw(lw , psis_object )
270
280
stopifnot(identical(dim(yrep ), dim(lw )))
271
281
pit <- rstantools :: loo_pit(object = yrep , y = y , lw = lw )
272
282
}
@@ -295,22 +305,24 @@ ppc_loo_pit_data <-
295
305
# ' @export
296
306
ppc_loo_pit_qq <- function (y ,
297
307
yrep ,
298
- lw ,
299
- pit ,
300
- compare = c(" uniform" , " normal" ),
308
+ lw = NULL ,
301
309
... ,
310
+ psis_object = NULL ,
311
+ pit = NULL ,
312
+ compare = c(" uniform" , " normal" ),
302
313
size = 2 ,
303
314
alpha = 1 ) {
304
315
check_ignored_arguments(... )
305
316
306
317
compare <- match.arg(compare )
307
- if (! missing (pit )) {
318
+ if (! is.null (pit )) {
308
319
stopifnot(is.numeric(pit ), is_vector_or_1Darray(pit ))
309
320
inform(" 'pit' specified so ignoring 'y','yrep','lw' if specified." )
310
321
} else {
311
322
suggested_package(" rstantools" )
312
323
y <- validate_y(y )
313
324
yrep <- validate_predictions(yrep , length(y ))
325
+ lw <- .get_lw(lw , psis_object )
314
326
stopifnot(identical(dim(yrep ), dim(lw )))
315
327
pit <- rstantools :: loo_pit(object = yrep , y = y , lw = lw )
316
328
}
@@ -352,7 +364,7 @@ ppc_loo_pit <-
352
364
function (y ,
353
365
yrep ,
354
366
lw ,
355
- pit ,
367
+ pit = NULL ,
356
368
compare = c(" uniform" , " normal" ),
357
369
... ,
358
370
size = 2 ,
@@ -374,18 +386,14 @@ ppc_loo_pit <-
374
386
# ' @rdname PPC-loo
375
387
# ' @export
376
388
# ' @template args-prob-prob_outer
377
- # ' @param psis_object If using **loo** version `2.0.0` or greater, an
378
- # ' object returned by the `psis()` function (or by the `loo()` function
379
- # ' with argument `save_psis` set to `TRUE`).
380
- # ' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`,
381
- # ' optionally a matrix of precomputed LOO predictive intervals
382
- # ' that can be specified instead of `yrep` and `lw` (these are both
383
- # ' ignored if `intervals` is specified). If not specified the intervals
384
- # ' are computed internally before plotting. If specified, `intervals`
385
- # ' must be a matrix with number of rows equal to the number of data points and
386
- # ' five columns in the following order: lower outer interval, lower inner
387
- # ' interval, median (50%), upper inner interval and upper outer interval
388
- # ' (column names are ignored).
389
+ # ' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`, optionally
390
+ # ' a matrix of pre-computed LOO predictive intervals that can be specified
391
+ # ' instead of `yrep` (ignored if `intervals` is specified). If not specified
392
+ # ' the intervals are computed internally before plotting. If specified,
393
+ # ' `intervals` must be a matrix with number of rows equal to the number of
394
+ # ' data points and five columns in the following order: lower outer interval,
395
+ # ' lower inner interval, median (50%), upper inner interval and upper outer
396
+ # ' interval (column names are ignored).
389
397
# ' @param order For `ppc_loo_intervals()`, a string indicating how to arrange
390
398
# ' the plotted intervals. The default (`"index"`) is to plot them in the
391
399
# ' order of the observations. The alternative (`"median"`) arranges them
@@ -403,9 +411,9 @@ ppc_loo_intervals <-
403
411
function (y ,
404
412
yrep ,
405
413
psis_object ,
414
+ ... ,
406
415
subset = NULL ,
407
416
intervals = NULL ,
408
- ... ,
409
417
prob = 0.5 ,
410
418
prob_outer = 0.9 ,
411
419
alpha = 0.33 ,
@@ -498,11 +506,10 @@ ppc_loo_intervals <-
498
506
ppc_loo_ribbon <-
499
507
function (y ,
500
508
yrep ,
501
- lw ,
502
509
psis_object ,
510
+ ... ,
503
511
subset = NULL ,
504
512
intervals = NULL ,
505
- ... ,
506
513
prob = 0.5 ,
507
514
prob_outer = 0.9 ,
508
515
alpha = 0.33 ,
@@ -720,3 +727,17 @@ ppc_loo_ribbon <-
720
727
721
728
list (xs = xs , unifs = bc_mat )
722
729
}
730
+
731
+ # Extract log weights from psis_object if provided
732
+ .get_lw <- function (lw = NULL , psis_object = NULL ) {
733
+ if (is.null(lw ) && is.null(psis_object )) {
734
+ abort(" One of 'lw' and 'psis_object' must be specified." )
735
+ } else if (is.null(lw )) {
736
+ suggested_package(" loo" , min_version = " 2.0.0" )
737
+ if (! loo :: is.psis(psis_object )) {
738
+ abort(" If specified, 'psis_object' must be a PSIS object from the loo package." )
739
+ }
740
+ lw <- loo :: weights.importance_sampling(psis_object )
741
+ }
742
+ lw
743
+ }
0 commit comments