Skip to content

Commit 51db99e

Browse files
authored
test compute_grid_info() (#948)
1 parent f8d734a commit 51db99e

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed

tests/testthat/test-grid_helpers.R

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
test_that("compute_grid_info - recipe only", {
2+
library(workflows)
3+
library(recipes)
4+
library(parsnip)
5+
library(dials)
6+
7+
rec <- recipe(mpg ~ ., mtcars) %>% step_spline_natural(deg_free = tune())
8+
9+
wflow <- workflow()
10+
wflow <- add_model(wflow, boost_tree(mode = "regression"))
11+
wflow <- add_recipe(wflow, rec)
12+
13+
grid <- grid_space_filling(extract_parameter_set_dials(wflow))
14+
res <- compute_grid_info(wflow, grid)
15+
16+
expect_equal(res$.iter_preprocessor, 1:5)
17+
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5"))
18+
expect_equal(res$deg_free, grid$deg_free)
19+
expect_equal(res$.iter_model, rep(1, 5))
20+
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1")))
21+
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1"))
22+
expect_equal(res$.submodels, list(list(), list(), list(), list(), list()))
23+
expect_named(
24+
res,
25+
c(".iter_preprocessor", ".msg_preprocessor", "deg_free", ".iter_model",
26+
".iter_config", ".msg_model", ".submodels"),
27+
ignore.order = TRUE
28+
)
29+
expect_equal(nrow(res), 5)
30+
})
31+
32+
test_that("compute_grid_info - model only (no submodels)", {
33+
library(workflows)
34+
library(parsnip)
35+
library(dials)
36+
37+
spec <- boost_tree(mode = "regression", learn_rate = tune())
38+
39+
wflow <- workflow()
40+
wflow <- add_model(wflow, spec)
41+
wflow <- add_formula(wflow, mpg ~ .)
42+
43+
grid <- grid_space_filling(extract_parameter_set_dials(wflow))
44+
res <- compute_grid_info(wflow, grid)
45+
46+
expect_equal(res$.iter_preprocessor, rep(1, 5))
47+
expect_equal(res$.msg_preprocessor, rep("preprocessor 1/1", 5))
48+
expect_equal(res$learn_rate, grid$learn_rate)
49+
expect_equal(res$.iter_model, 1:5)
50+
expect_equal(res$.iter_config, as.list(paste0("Preprocessor1_Model", 1:5)))
51+
expect_equal(res$.msg_model, paste0("preprocessor 1/1, model ", 1:5, "/5"))
52+
expect_equal(res$.submodels, list(list(), list(), list(), list(), list()))
53+
expect_named(
54+
res,
55+
c(".iter_preprocessor", ".msg_preprocessor", "learn_rate", ".iter_model",
56+
".iter_config", ".msg_model", ".submodels"),
57+
ignore.order = TRUE
58+
)
59+
expect_equal(nrow(res), 5)
60+
})
61+
62+
test_that("compute_grid_info - model only (with submodels)", {
63+
library(workflows)
64+
library(parsnip)
65+
library(dials)
66+
67+
spec <- boost_tree(mode = "regression", trees = tune())
68+
69+
wflow <- workflow()
70+
wflow <- add_model(wflow, spec)
71+
wflow <- add_formula(wflow, mpg ~ .)
72+
73+
grid <- grid_space_filling(extract_parameter_set_dials(wflow))
74+
res <- compute_grid_info(wflow, grid)
75+
76+
expect_equal(res$.iter_preprocessor, 1)
77+
expect_equal(res$.msg_preprocessor, "preprocessor 1/1")
78+
expect_equal(res$trees, max(grid$trees))
79+
expect_equal(res$.iter_model, 1)
80+
expect_equal(res$.iter_config, list(paste0("Preprocessor1_Model", 1:5)))
81+
expect_equal(res$.msg_model, "preprocessor 1/1, model 1/1")
82+
expect_equal(res$.submodels, list(list(trees = grid$trees[-which.max(grid$trees)])))
83+
expect_named(
84+
res,
85+
c(".iter_preprocessor", ".msg_preprocessor", "trees", ".iter_model",
86+
".iter_config", ".msg_model", ".submodels"),
87+
ignore.order = TRUE
88+
)
89+
expect_equal(nrow(res), 1)
90+
})
91+
92+
test_that("compute_grid_info - recipe and model (no submodels)", {
93+
library(workflows)
94+
library(parsnip)
95+
library(recipes)
96+
library(dials)
97+
98+
rec <- recipe(mpg ~ ., mtcars) %>% step_spline_natural(deg_free = tune())
99+
spec <- boost_tree(mode = "regression", learn_rate = tune())
100+
101+
wflow <- workflow()
102+
wflow <- add_model(wflow, spec)
103+
wflow <- add_recipe(wflow, rec)
104+
105+
grid <- grid_space_filling(extract_parameter_set_dials(wflow))
106+
res <- compute_grid_info(wflow, grid)
107+
108+
expect_equal(res$.iter_preprocessor, 1:5)
109+
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5"))
110+
expect_equal(res$learn_rate, grid$learn_rate)
111+
expect_equal(res$deg_free, grid$deg_free)
112+
expect_equal(res$.iter_model, rep(1, 5))
113+
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1")))
114+
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1"))
115+
expect_equal(res$.submodels, list(list(), list(), list(), list(), list()))
116+
expect_named(
117+
res,
118+
c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "learn_rate",
119+
".iter_model", ".iter_config", ".msg_model", ".submodels"),
120+
ignore.order = TRUE
121+
)
122+
expect_equal(nrow(res), 5)
123+
})
124+
125+
test_that("compute_grid_info - recipe and model (with submodels)", {
126+
library(workflows)
127+
library(parsnip)
128+
library(recipes)
129+
library(dials)
130+
131+
rec <- recipe(mpg ~ ., mtcars) %>% step_spline_natural(deg_free = tune())
132+
spec <- boost_tree(mode = "regression", trees = tune())
133+
134+
wflow <- workflow()
135+
wflow <- add_model(wflow, spec)
136+
wflow <- add_recipe(wflow, rec)
137+
138+
# use grid_regular to trigger submodel trick
139+
set.seed(1)
140+
grid <- grid_regular(extract_parameter_set_dials(wflow))
141+
res <- compute_grid_info(wflow, grid)
142+
143+
expect_equal(res$.iter_preprocessor, 1:3)
144+
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:3, "/3"))
145+
expect_equal(res$trees, rep(max(grid$trees), 3))
146+
expect_equal(res$.iter_model, rep(1, 3))
147+
expect_equal(
148+
res$.iter_config,
149+
list(
150+
c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3"),
151+
c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"),
152+
c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3")
153+
)
154+
)
155+
expect_equal(res$.msg_model, paste0("preprocessor ", 1:3, "/3, model 1/1"))
156+
expect_equal(
157+
res$.submodels,
158+
list(
159+
list(trees = c(1L, 1000L)),
160+
list(trees = c(1L, 1000L)),
161+
list(trees = c(1L, 1000L))
162+
)
163+
)
164+
expect_named(
165+
res,
166+
c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees",
167+
".iter_model", ".iter_config", ".msg_model", ".submodels"),
168+
ignore.order = TRUE
169+
)
170+
expect_equal(nrow(res), 3)
171+
})

0 commit comments

Comments
 (0)