diff --git a/NEWS.md b/NEWS.md index 7bcc32375c..3957042f3b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -314,6 +314,8 @@ (@teunbrand, #4335). * `ggsave()` can write a multi-page pdf file when provided with a list of plots (@teunbrand, #5093). +* (internal) When `validate_subclass()` fails to find a class directly, it tries + to retrieve the class via constructor functions (@teunbrand). # ggplot2 3.5.1 diff --git a/R/layer.R b/R/layer.R index 2cd10c447f..8ed938def6 100644 --- a/R/layer.R +++ b/R/layer.R @@ -458,18 +458,50 @@ validate_subclass <- function(x, subclass, if (inherits(x, subclass)) { return(x) - } else if (is_scalar_character(x)) { - name <- paste0(subclass, camelize(x, first = TRUE)) - obj <- find_global(name, env = env) + } + if (!is_scalar_character(x)) { + stop_input_type(x, as_cli("either a string or a {.cls {subclass}} object"), arg = x_arg) + } - if (is.null(obj) || !inherits(obj, subclass)) { - cli::cli_abort("Can't find {argname} called {.val {x}}.", call = call) - } + # Try getting class object directly + name <- paste0(subclass, camelize(x, first = TRUE)) + obj <- find_global(name, env = env) + if (inherits(obj, subclass)) { + return(obj) + } + + # Try retrieving class via constructors + name <- snakeize(name) + obj <- find_global(name, env = env, mode = "function") + if (is.function(obj)) { + obj <- try_fetch( + obj(), + error = function(cnd) { + # replace `obj()` call with name of actual constructor + cnd$call <- call(name) + cli::cli_abort( + "Failed to retrieve a {.cls {subclass}} object from {.fn {name}}.", + parent = cnd, call = call + ) + }) + } + # Position constructors return classes directly + if (inherits(obj, subclass)) { + return(obj) + } + # Try prying the class from a layer + if (inherits(obj, "Layer")) { + obj <- switch( + subclass, + Geom = obj$geom, + Stat = obj$stat, + NULL + ) + } + if (inherits(obj, subclass)) { return(obj) - } else if (is.null(x)) { - cli::cli_abort("The {.arg {x_arg}} argument cannot be empty.", call = call) } - stop_input_type(x, as_cli("either a string or a {.cls {subclass}} object")) + cli::cli_abort("Can't find {argname} called {.val {x}}.", call = call) } # helper function to adjust the draw_key slot of a geom diff --git a/tests/testthat/_snaps/layer.md b/tests/testthat/_snaps/layer.md index 79b561b17d..c796c6a530 100644 --- a/tests/testthat/_snaps/layer.md +++ b/tests/testthat/_snaps/layer.md @@ -1,14 +1,14 @@ # layer() checks its input - The `geom` argument cannot be empty. + `geom` must be either a string or a object, not `NULL`. --- - The `stat` argument cannot be empty. + `stat` must be either a string or a object, not `NULL`. --- - The `position` argument cannot be empty. + `position` must be either a string or a object, not `NULL`. --- @@ -25,7 +25,13 @@ --- - `x` must be either a string or a object, not an environment. + `environment()` must be either a string or a object, not an environment. + +--- + + Failed to retrieve a object from `geom_foo()`. + Caused by error in `geom_foo()`: + ! This function is unconstructable. # unknown params create warning diff --git a/tests/testthat/test-layer.R b/tests/testthat/test-layer.R index f901d3b62f..5e2dbf1d2b 100644 --- a/tests/testthat/test-layer.R +++ b/tests/testthat/test-layer.R @@ -10,6 +10,9 @@ test_that("layer() checks its input", { expect_snapshot_error(validate_subclass("test", "geom")) expect_snapshot_error(validate_subclass(environment(), "geom")) + + geom_foo <- function(...) stop("This function is unconstructable.") + expect_snapshot_error(layer("foo", "identity", position = "identity")) }) test_that("aesthetics go in aes_params", { @@ -154,6 +157,22 @@ test_that("layer names can be resolved", { expect_snapshot(p + l + l, error = TRUE) }) +test_that("check_subclass can resolve classes via constructors", { + + env <- new_environment(list( + geom_foobar = geom_point, + stat_foobar = stat_boxplot, + position_foobar = position_nudge, + guide_foobar = guide_axis_theta + )) + + expect_s3_class(validate_subclass("foobar", "Geom", env = env), "GeomPoint") + expect_s3_class(validate_subclass("foobar", "Stat", env = env), "StatBoxplot") + expect_s3_class(validate_subclass("foobar", "Position", env = env), "PositionNudge") + expect_s3_class(validate_subclass("foobar", "Guide", env = env), "GuideAxisTheta") + +}) + test_that("attributes on layer data are preserved", { # This is a good layer for testing because: # * It needs to compute a statistic at the group level