Skip to content

Commit f8d734a

Browse files
authored
transition from add_tailor(prop) and method (#945)
1 parent 18442b2 commit f8d734a

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

R/grid_code_paths.R

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ tune_grid_loop_iter <- function(split,
384384
assessment_rows <- as.integer(split, data = "assessment")
385385
assessment <- vctrs::vec_slice(split$data, assessment_rows)
386386

387-
if (workflows::.should_inner_split(workflow)) {
387+
if (workflows::.workflow_includes_calibration(workflow)) {
388388
# if the workflow has a postprocessor that needs training (i.e. calibration),
389389
# further split the analysis data into an "inner" analysis and
390390
# assessment set.
@@ -397,11 +397,6 @@ tune_grid_loop_iter <- function(split,
397397
# calibration set
398398
# * the model (including the post-processor) generates predictions on the
399399
# assessment set and those predictions are assessed with performance metrics
400-
# todo: check if workflow's `method` is incompatible with `class(split)`?
401-
# todo: workflow's `method` is currently ignored in favor of the one
402-
# automatically dispatched to from `split`. consider this is combination
403-
# with above todo.
404-
split_args <- c(split_args, list(prop = workflow$post$actions$tailor$prop))
405400
split <- rsample::inner_split(split, split_args = split_args)
406401
analysis <- rsample::analysis(split)
407402

tests/testthat/test-last-fit.R

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,7 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini
246246
parsnip::linear_reg()
247247
) %>%
248248
workflows::add_tailor(
249-
tailor::tailor("regression") %>% tailor::adjust_numeric_calibration("linear"),
250-
prop = 2/3,
251-
method = class(split)
249+
tailor::tailor() %>% tailor::adjust_numeric_calibration("linear")
252250
)
253251

254252
set.seed(1)
@@ -261,13 +259,21 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini
261259
last_fit_preds <- collect_predictions(last_fit_res)
262260

263261
set.seed(1)
264-
wflow_res <- generics::fit(wflow, rsample::analysis(split))
262+
inner_split <- rsample::inner_split(split, split_args = list())
263+
264+
set.seed(1)
265+
wflow_res <-
266+
generics::fit(
267+
wflow,
268+
rsample::analysis(inner_split),
269+
calibration = rsample::assessment(inner_split)
270+
)
265271
wflow_preds <- predict(wflow_res, rsample::assessment(split))
266272

267273
expect_equal(last_fit_preds[".pred"], wflow_preds)
268274
})
269275

270-
test_that("can use `last_fit()` with a workflow - postprocessor (requires training)", {
276+
test_that("can use `last_fit()` with a workflow - postprocessor (does not require training)", {
271277
skip_if_not_installed("tailor")
272278

273279
y <- seq(0, 7, .001)
@@ -284,9 +290,7 @@ test_that("can use `last_fit()` with a workflow - postprocessor (requires traini
284290
parsnip::linear_reg()
285291
) %>%
286292
workflows::add_tailor(
287-
tailor::tailor("regression") %>% tailor::adjust_numeric_range(lower_limit = 1),
288-
prop = 2/3,
289-
method = class(split)
293+
tailor::tailor() %>% tailor::adjust_numeric_range(lower_limit = 1)
290294
)
291295

292296
set.seed(1)

tests/testthat/test-resample.R

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,7 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (requires t
151151
parsnip::linear_reg()
152152
) %>%
153153
workflows::add_tailor(
154-
tailor::tailor("regression") %>% tailor::adjust_numeric_calibration("linear"),
155-
prop = 2/3,
156-
method = class(folds$splits[[1]])
154+
tailor::tailor() %>% tailor::adjust_numeric_calibration("linear")
157155
)
158156

159157
set.seed(1)
@@ -178,8 +176,20 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (requires t
178176
seed <- generate_seeds(TRUE, 1)[[1]]
179177
old_kind <- RNGkind()[[1]]
180178
assign(".Random.seed", seed, envir = globalenv())
179+
withr::defer(RNGkind(kind = old_kind))
181180

182-
wflow_res <- generics::fit(wflow, rsample::analysis(folds$splits[[1]]))
181+
inner_split_1 <-
182+
rsample::inner_split(
183+
folds$splits[[1]],
184+
split_args = list(v = 2, repeats = 1, breaks = 4, pool = 0.1)
185+
)
186+
187+
wflow_res <-
188+
generics::fit(
189+
wflow,
190+
rsample::analysis(inner_split_1),
191+
calibration = rsample::assessment(inner_split_1)
192+
)
183193
wflow_preds <- predict(wflow_res, rsample::assessment(folds$splits[[1]]))
184194

185195
tune_wflow$fit$fit$elapsed$elapsed <- wflow_res$fit$fit$elapsed$elapsed
@@ -201,7 +211,7 @@ test_that("can use `fit_resamples()` with a workflow - postprocessor (no trainin
201211
parsnip::linear_reg()
202212
) %>%
203213
workflows::add_tailor(
204-
tailor::tailor("regression") %>% tailor::adjust_numeric_range(lower_limit = 1)
214+
tailor::tailor() %>% tailor::adjust_numeric_range(lower_limit = 1)
205215
)
206216

207217
set.seed(1)

0 commit comments

Comments
 (0)