add fun_avg to ppc_avg functions #349

Merged
merged 10 commits into from
Jul 3, 2025
Changes from 1 commit
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

* Add possibility for left-truncation to `ppc_km_overlay()` and `ppc_km_overlay_grouped()` by @Sakuski
* Added `ppc_loo_pit_ecdf()` by @TeemuSailynoja
* PPC "avg" functions (`ppc_scatter_avg()`, `ppc_error_scatter_avg()`, etc.) gain a `fun_avg` argument to set the averaging function. (Suggestion of #348, @kruschke).

# bayesplot 1.12.0

14 changes: 11 additions & 3 deletions R/ppc-errors.R
@@ -10,6 +10,8 @@
#' @template args-group
#' @template args-facet_args
#' @param ... Currently unused.
#' @param fun_avg Function to apply to compute the posterior average.
#' Defaults to `"mean"`.
Collaborator

The ppc_stat functions already have a similar argument, named stat, which does much the same job:

#' @param stat A single function or a string naming a function, except for the
#' 2D plot which requires a vector of exactly two names or functions. In all
#' cases the function(s) should take a vector input and return a scalar
#' statistic. If specified as a string (or strings) then the legend will
#' display the function name(s). If specified as a function (or functions)
#' then generic naming is used in the legend.

We could align this doc to read, for example:

#' @param fun_avg A function or a string naming a function for computing the
#' posterior average. In both cases, the function should take a vector input and
#' return a scalar statistic. If specified as a string, then the legend will
#' display the function name. If specified as a function
#' then generic naming is used in the legend.
#' Defaults to "mean".
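
The contract in this suggested doc (a function or a string naming one, vector in, scalar out, with the legend name taken from the string) can be sketched in base R. This is only an illustration: `resolve_avg()` is a hypothetical helper, not a bayesplot function, and the `yrep` matrix is made up.

```r
# Hypothetical helper (not part of bayesplot): resolve `fun_avg` and choose
# a legend name the way the suggested doc describes.
resolve_avg <- function(fun_avg = "mean") {
  # A string keeps its own name for the legend; a function gets a generic name.
  name <- if (is.character(fun_avg)) fun_avg else "Average"
  # match.fun() accepts either a function or a string naming a function.
  list(fun = match.fun(fun_avg), name = name)
}

# Made-up predictions: 3 draws (rows) x 2 observations (columns).
# matrix() fills column by column, so column 1 is (1, 2, 6), column 2 is (3, 4, 8).
yrep <- matrix(c(1, 2, 6, 3, 4, 8), nrow = 3)

by_name <- resolve_avg("median")
by_fun  <- resolve_avg(function(x) mean(x, trim = 0.1))

col_medians <- apply(yrep, 2, by_name$fun)  # one scalar per observation
```

Here `by_name$name` is `"median"` while `by_fun$name` falls back to `"Average"`, which mirrors the string-versus-function naming rule in the doc text.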

Collaborator Author

The Average y - y_rep axis label is not affected. I didn't want to make yrep_avg_label() and error_avg_label() depend on fun_avg or change the default "Average y - y_rep" labels.

It does affect the $rep_label in ppc_scatter_avg_data(y, yrep) when fun_avg is a string.

Collaborator Author

Wait, I'm just noticing Aki's comment. I'll switch to stat.

Member

Thanks @tjmahr. The code looks good, and I agree with changing to stat.

#' @param size,alpha For scatterplots, arguments passed to
#' [ggplot2::geom_point()] to control the appearance of the points. For the
#' binned error plot, arguments controlling the size of the outline and
@@ -209,6 +211,7 @@ ppc_error_scatter_avg <-
function(y,
yrep,
...,
fun_avg = "mean",
size = 2.5,
alpha = 0.8) {
check_ignored_arguments(...)
@@ -221,7 +224,8 @@ ppc_error_scatter_avg <-
yrep = errors,
size = size,
alpha = alpha,
- ref_line = FALSE
+ ref_line = FALSE,
+ fun_avg = fun_avg
) +
labs(x = error_avg_label(), y = y_label())
}
@@ -234,6 +238,7 @@ ppc_error_scatter_avg_grouped <-
yrep,
group,
...,
fun_avg = "mean",
facet_args = list(),
size = 2.5,
alpha = 0.8) {
@@ -249,7 +254,8 @@ ppc_error_scatter_avg_grouped <-
size = size,
alpha = alpha,
facet_args = facet_args,
- ref_line = FALSE
+ ref_line = FALSE,
+ fun_avg = fun_avg
) +
labs(x = error_avg_label(), y = y_label())
}
@@ -265,6 +271,7 @@ ppc_error_scatter_avg_vs_x <-
yrep,
x,
...,
fun_avg = "mean",
size = 2.5,
alpha = 0.8) {
check_ignored_arguments(...)
@@ -278,7 +285,8 @@ ppc_error_scatter_avg_vs_x <-
yrep = errors,
size = size,
alpha = alpha,
- ref_line = FALSE
+ ref_line = FALSE,
+ fun_avg = fun_avg
) +
labs(x = error_avg_label(), y = expression(italic(x))) +
coord_flip()
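
In the hunks above, the error functions pass `yrep = errors` together with `fun_avg`, so the averaging function is applied to the predictive errors per observation. A minimal base-R sketch of that step follows. It assumes errors are computed as `y - yrep` (consistent with the "Average y - y_rep" label, but the computation itself is not shown in this diff); `avg_error()` is a hypothetical helper and the data are simulated.

```r
set.seed(1)
y    <- c(1.5, -0.3, 2.0)
yrep <- matrix(rnorm(4 * 3), nrow = 4)  # 4 draws (rows) x 3 observations

# Assumed error definition: errors[s, n] = y[n] - yrep[s, n].
# sweep() with the default FUN = "-" gives yrep - y, so negate it.
errors <- -sweep(yrep, 2, y)

# Hypothetical helper: apply the averaging function to each column of errors.
avg_error <- function(errors, fun_avg = "mean") {
  apply(errors, 2, FUN = fun_avg)
}

mean_err   <- avg_error(errors, "mean")
median_err <- avg_error(errors, "median")
```

For both the mean and the median, averaging the errors gives the same result as subtracting the averaged `yrep` from `y`. Other summary functions need not have this property.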
27 changes: 19 additions & 8 deletions R/ppc-scatterplots.R
@@ -11,6 +11,8 @@
#' @template args-group
#' @template args-facet_args
#' @param ... Currently unused.
#' @param fun_avg Function to apply to compute the posterior average.
#' Defaults to `"mean"`.
Collaborator

same as above.

#' @param size,alpha Arguments passed to [ggplot2::geom_point()] to control the
#' appearance of the points.
#' @param ref_line If `TRUE` (the default) a dashed line with intercept 0 and
@@ -31,10 +33,10 @@
#' }
#' \item{`ppc_scatter_avg()`}{
#' A single scatterplot of `y` against the average values of `yrep`, i.e.,
- #' the points `(x,y) = (mean(yrep[, n]), y[n])`, where each `yrep[, n]` is
- #' a vector of length equal to the number of posterior draws. Unlike
- #' for `ppc_scatter()`, for `ppc_scatter_avg()` `yrep` should contain many
- #' draws (rows).
+ #' the points `(x,y) = (average(yrep[, n]), y[n])`, where each `yrep[, n]` is
+ #' a vector of length equal to the number of posterior draws and `average()`
+ #' is a summary statistic. Unlike for `ppc_scatter()`, for `ppc_scatter_avg()`
+ #' `yrep` should contain many draws (rows).
#' }
#' \item{`ppc_scatter_avg_grouped()`}{
#' The same as `ppc_scatter_avg()`, but a separate plot is generated for
@@ -59,6 +61,9 @@
#' p1 + lims
#' p2 + lims
#'
#' # "average" function is customizable
#' ppc_scatter_avg(y, yrep, fun_avg = "median", ref_line = FALSE)
#'
#' # for ppc_scatter_avg_grouped the default is to allow the facets
#' # to have different x and y axes
#' group <- example_group_data()
@@ -116,6 +121,7 @@ ppc_scatter_avg <-
function(y,
yrep,
...,
fun_avg = "mean",
size = 2.5,
alpha = 0.8,
ref_line = TRUE) {
@@ -125,7 +131,7 @@ ppc_scatter_avg <-
dots$group <- NULL
}

- data <- ppc_scatter_avg_data(y, yrep, group = dots$group)
+ data <- ppc_scatter_avg_data(y, yrep, group = dots$group, fun_avg = fun_avg)
if (is.null(dots$group) && nrow(yrep) == 1) {
inform(
"With only 1 row in 'yrep' ppc_scatter_avg is the same as ppc_scatter."
@@ -155,6 +161,7 @@ ppc_scatter_avg_grouped <-
yrep,
group,
...,
fun_avg = "mean",
facet_args = list(),
size = 2.5,
alpha = 0.8,
@@ -184,16 +191,20 @@ ppc_scatter_data <- function(y, yrep) {

#' @rdname PPC-scatterplots
#' @export
- ppc_scatter_avg_data <- function(y, yrep, group = NULL) {
+ ppc_scatter_avg_data <- function(y, yrep, group = NULL, fun_avg = "mean") {
y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
if (!is.null(group)) {
group <- validate_group(group, length(y))
}

- data <- ppc_scatter_data(y = y, yrep = t(colMeans(yrep)))
+ data <- ppc_scatter_data(y = y, yrep = t(apply(yrep, 2, FUN = fun_avg)))
data$rep_id <- NA_integer_
- levels(data$rep_label) <- "mean(italic(y)[rep]))"
+ if (is.character(fun_avg)) {
+   levels(data$rep_label) <- sprintf("%s(italic(y)[rep]))", fun_avg)
+ } else {
+   levels(data$rep_label) <- "Average(italic(y)[rep]))"
+ }

if (!is.null(group)) {
data <- tibble::add_column(data,
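
The key change in `ppc_scatter_avg_data()` replaces `t(colMeans(yrep))` with `t(apply(yrep, 2, FUN = fun_avg))`. A quick base-R check, using a made-up `yrep` matrix rather than data from the PR, suggests the default `"mean"` reproduces the old result and that other summaries keep the same 1-row shape:

```r
# Made-up predictions: 4 draws (rows) x 2 observations (columns),
# each column containing one large outlying draw.
yrep <- matrix(c(1, 2, 3, 10, 4, 5, 6, 20), nrow = 4)

old_avg <- t(colMeans(yrep))                   # behaviour before the change
new_avg <- t(apply(yrep, 2, FUN = "mean"))     # default after the change
med_avg <- t(apply(yrep, 2, FUN = "median"))   # alternative summary

same_default <- isTRUE(all.equal(old_avg, new_avg))
```

Both versions return a 1 x N matrix, which is the shape passed on to `ppc_scatter_data()` in the diff. This relies on the averaging function returning a scalar: a function that returns a vector would make `apply()` produce a matrix with more than one row after transposing.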
16 changes: 14 additions & 2 deletions man/PPC-errors.Rd

Some generated files are not rendered by default.

27 changes: 21 additions & 6 deletions man/PPC-scatterplots.Rd

Some generated files are not rendered by default.