Skip to content

Commit 1184068

Browse files
topeposimonpcouch
andauthored
add rlang type checkers (#950)
* add type checking files * remove newly unneeded checking functions * snapshot updates from tidymodels/recipes#1381 * updates files * basic replacements * type checker replacements * tidymodels/tailor#53 * Update R/checks.R Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> * add remote to get proper error messages * typo * update remotes? * only test snapshots with more recent version of R *with* rankdeficient --------- Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com>
1 parent bc5422a commit 1184068

23 files changed

+1144
-243
lines changed

DESCRIPTION

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Version: 1.2.1.9000
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"),
66
comment = c(ORCID = "0000-0003-2402-136X")),
7-
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
7+
person("Posit Software, PBC", role = c("cph", "fnd"))
88
)
99
Description: The ability to tune models is important. 'tune' contains
1010
functions and classes to be used in conjunction with other
@@ -27,12 +27,12 @@ Imports:
2727
ggplot2,
2828
glue (>= 1.6.2),
2929
GPfit,
30-
hardhat (>= 1.2.0),
30+
hardhat (>= 1.4.0.9002),
3131
lifecycle (>= 1.0.0),
32-
parsnip (>= 1.2.0),
32+
parsnip (>= 1.2.1.9003),
3333
purrr (>= 1.0.0),
34-
recipes (>= 1.0.4),
35-
rlang (>= 1.1.0),
34+
recipes (>= 1.1.0.9001),
35+
rlang (>= 1.1.4),
3636
rsample (>= 1.2.1.9000),
3737
tailor,
3838
tibble (>= 3.1.0),
@@ -57,8 +57,11 @@ Suggests:
5757
xgboost,
5858
xml2
5959
Remotes:
60+
tidymodels/hardhat,
61+
tidymodels/parsnip,
62+
tidymodels/recipes,
6063
tidymodels/rsample,
61-
tidymodels/tailor,
64+
tidymodels/tailor,
6265
tidymodels/workflows
6366
Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture,
6467
tidyverse/tidytemplate

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ export(tune_bayes)
244244
export(tune_grid)
245245
export(val_class_and_single)
246246
export(val_class_or_null)
247+
import(rlang)
247248
import(vctrs)
248249
import(workflows)
249250
importFrom(GPfit,GP_fit)

R/0_imports.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#' @importFrom cli cli_inform cli_warn cli_abort qty
2323
#' @importFrom foreach foreach getDoParName %dopar%
2424
#' @importFrom tibble obj_sum size_sum
25-
25+
#' @import rlang
2626

2727
# ------------------------------------------------------------------------------
2828
# Only a small number of functions in workflows.

R/acquisition.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ print.prob_improve <- function(x, ...) {
7474
#' @export
7575
predict.prob_improve <-
7676
function(object, new_data, maximize, iter, best, ...) {
77-
check_direction(maximize)
78-
check_best(best)
77+
check_bool(maximize)
78+
check_number_decimal(best, allow_infinite = FALSE)
7979

8080
if (is.function(object$trade_off)) {
8181
trade_off <- object$trade_off(iter)
@@ -126,8 +126,8 @@ exp_improve <- function(trade_off = 0, eps = .Machine$double.eps) {
126126

127127
#' @export
128128
predict.exp_improve <- function(object, new_data, maximize, iter, best, ...) {
129-
check_direction(maximize)
130-
check_best(best)
129+
check_bool(maximize)
130+
check_number_decimal(best, allow_infinite = FALSE)
131131

132132
if (is.function(object$trade_off)) {
133133
trade_off <- object$trade_off(iter)
@@ -177,7 +177,7 @@ conf_bound <- function(kappa = 0.1) {
177177

178178
#' @export
179179
predict.conf_bound <- function(object, new_data, maximize, iter, ...) {
180-
check_direction(maximize)
180+
check_bool(maximize)
181181

182182
if (is.function(object$kappa)) {
183183
kappa <- object$kappa(iter)

R/checks.R

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -493,26 +493,6 @@ get_objective_name <- function(x, metrics) {
493493
x
494494
}
495495

496-
497-
# ------------------------------------------------------------------------------
498-
# acq functions
499-
500-
check_direction <- function(x) {
501-
if (!is.logical(x) || length(x) != 1) {
502-
rlang::abort("`maximize` should be a single logical.")
503-
}
504-
invisible(NULL)
505-
}
506-
507-
508-
check_best <- function(x) {
509-
if (!is.numeric(x) || length(x) != 1 || is.na(x)) {
510-
rlang::abort("`best` should be a single, non-missing numeric.")
511-
}
512-
invisible(NULL)
513-
}
514-
515-
516496
# ------------------------------------------------------------------------------
517497

518498
check_class_or_null <- function(x, cls = "numeric") {
@@ -537,6 +517,7 @@ val_class_or_null <- function(x, cls = "numeric", where = NULL) {
537517
}
538518
invisible(NULL)
539519
}
520+
# TODO remove this once finetune is updated
540521

541522
check_class_and_single <- function(x, cls = "numeric") {
542523
isTRUE(inherits(x, cls) & length(x) == 1)
@@ -558,7 +539,7 @@ val_class_and_single <- function(x, cls = "numeric", where = NULL) {
558539
}
559540
invisible(NULL)
560541
}
561-
542+
# TODO remove this once finetune is updated
562543

563544
# Check the data going into the GP. If there are all missing values, fail. If some
564545
# are missing, remove them and send a warning. If all metrics are the same, fail.
@@ -644,3 +625,11 @@ check_eval_time <- function(eval_time, metrics) {
644625
invisible(NULL)
645626

646627
}
628+
629+
check_time_limit_arg <- function(x, call = rlang::caller_env()) {
630+
if (!inherits(x, c("logical", "numeric")) || length(x) != 1L) {
631+
cli::cli_abort("{.arg time_limit} should be either a single numeric or
632+
logical value.", call = call)
633+
}
634+
invisible(NULL)
635+
}

R/compute_metrics.R

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ compute_metrics.tune_results <- function(x,
8787
summarize = TRUE,
8888
event_level = "first") {
8989
rlang::check_dots_empty()
90+
check_bool(summarize)
9091
if (!".predictions" %in% names(x)) {
9192
rlang::abort(paste0(
9293
"`x` must have been generated with the ",
@@ -114,10 +115,6 @@ compute_metrics.tune_results <- function(x,
114115
))
115116
}
116117

117-
if (!inherits(summarize, "logical") || length(summarize) != 1L) {
118-
rlang::abort("The `summarize` argument must be a single logical value.")
119-
}
120-
121118
param_names <- .get_tune_parameter_names(x)
122119
outcome_name <- .get_tune_outcome_names(x)
123120

R/control.R

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ control_grid <- function(verbose = FALSE, allow_par = TRUE,
3838
# Any added arguments should also be added in superset control functions
3939
# in other packages
4040

41-
# add options for seeds per resample
41+
# add options for seeds per resample
42+
check_bool(verbose)
43+
check_bool(allow_par)
44+
check_bool(save_pred)
45+
check_bool(save_workflow)
46+
check_string(event_level)
47+
check_character(pkgs, allow_null = TRUE)
48+
check_function(extract, allow_null = TRUE)
4249

43-
val_class_and_single(verbose, "logical", "control_grid()")
44-
val_class_and_single(allow_par, "logical", "control_grid()")
45-
val_class_and_single(save_pred, "logical", "control_grid()")
46-
val_class_and_single(save_workflow, "logical", "control_grid()")
47-
val_class_and_single(event_level, "character", "control_grid()")
48-
val_class_or_null(pkgs, "character", "control_grid()")
49-
val_class_or_null(extract, "function", "control_grid()")
5050
val_parallel_over(parallel_over, "control_grid()")
5151

5252

@@ -241,26 +241,27 @@ control_bayes <-
241241
# in other packages
242242

243243
# add options for seeds per resample
244+
check_bool(verbose)
245+
check_bool(verbose_iter)
246+
check_bool(allow_par)
247+
check_bool(save_pred)
248+
check_bool(save_workflow)
249+
check_bool(save_gp_scoring)
250+
check_character(pkgs, allow_null = TRUE)
251+
check_function(extract, allow_null = TRUE)
252+
check_number_whole(no_improve, min = 0, allow_infinite = TRUE)
253+
check_number_whole(uncertain, min = 0, allow_infinite = TRUE)
254+
check_number_whole(seed)
255+
256+
check_time_limit_arg(time_limit)
244257

245-
val_class_and_single(verbose, "logical", "control_bayes()")
246-
val_class_and_single(verbose_iter, "logical", "control_bayes()")
247-
val_class_and_single(save_pred, "logical", "control_bayes()")
248-
val_class_and_single(save_gp_scoring, "logical", "control_bayes()")
249-
val_class_and_single(save_workflow, "logical", "control_bayes()")
250-
val_class_and_single(no_improve, c("numeric", "integer"), "control_bayes()")
251-
val_class_and_single(uncertain, c("numeric", "integer"), "control_bayes()")
252-
val_class_and_single(seed, c("numeric", "integer"), "control_bayes()")
253-
val_class_or_null(extract, "function", "control_bayes()")
254-
val_class_and_single(time_limit, c("logical", "numeric"), "control_bayes()")
255-
val_class_or_null(pkgs, "character", "control_bayes()")
256-
val_class_and_single(event_level, "character", "control_bayes()")
257258
val_parallel_over(parallel_over, "control_bayes()")
258-
val_class_and_single(allow_par, "logical", "control_bayes()")
259259

260260

261261
if (!is.infinite(uncertain) && uncertain > no_improve) {
262-
cli::cli_alert_warning(
263-
"Uncertainty sample scheduled after {uncertain} poor iterations but the search will stop after {no_improve}."
262+
cli::cli_warn(
263+
"Uncertainty sample scheduled after {uncertain} poor iterations but the
264+
search will stop after {no_improve}."
264265
)
265266
}
266267

@@ -296,13 +297,11 @@ print.control_bayes <- function(x, ...) {
296297
# ------------------------------------------------------------------------------
297298

298299
val_parallel_over <- function(parallel_over, where) {
299-
if (is.null(parallel_over)) {
300-
return(invisible(NULL))
300+
check_string(parallel_over, allow_null = TRUE)
301+
if (!is.null(parallel_over)) {
302+
rlang::arg_match0(parallel_over, c("resamples", "everything"), "parallel_over")
301303
}
302304

303-
val_class_and_single(parallel_over, "character", where)
304-
rlang::arg_match0(parallel_over, c("resamples", "everything"), "parallel_over")
305-
306305
invisible(NULL)
307306
}
308307

R/extract.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,7 @@ extract_spec_parsnip.tune_results <- function(x, ...) {
109109
#' @rdname extract-tune
110110
extract_recipe.tune_results <- function(x, ..., estimated = TRUE) {
111111
check_empty_dots(...)
112-
if (!rlang::is_bool(estimated)) {
113-
rlang::abort("`estimated` must be a single `TRUE` or `FALSE`.")
114-
}
112+
check_bool(estimated)
115113
extract_recipe(extract_workflow(x), estimated = estimated)
116114
}
117115
check_empty_dots <- function(...) {

0 commit comments

Comments
 (0)