Skip to content

Commit 36cadfe

Browse files
committed
fix: proper task compability checking
1 parent 8de32c2 commit 36cadfe

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

R/Filter.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ Filter = R6Class("Filter",
164164

165165
fn = task$feature_names
166166

167-
if (!is_scalar_na(self$task_types) && task$task_type %nin% self$task_types) {
168-
stopf("Filter '%s' does not support the type '%s' of task '%s'",
167+
if (!is_scalar_na(self$task_types) && !some(self$task_types, test_matching_task_type, object = task, class = "learner")) {
168+
stopf("Filter '%s' not compatible with type '%s' of task '%s'",
169169
self$id, task$task_type, task$id)
170170
}
171171

R/helper.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,23 @@ call_praznik = function(self, task, fun, nfeat) {
1212
catn = function(..., file = "") {
1313
cat(paste0(..., collapse = "\n"), "\n", sep = "", file = file)
1414
}
15+
16+
17+
test_matching_task_type = function(task_type, object, class) {
18+
fget = function(tab, i, j, key) {
19+
x = tab[[key]]
20+
tab[[j]][x %chin% i]
21+
}
22+
23+
if (is.null(task_type) || object$task_type == task_type) {
24+
return(TRUE)
25+
}
26+
27+
cl_task_type = fget(mlr_reflections$task_types, task_type, class, "type")
28+
if (inherits(object, cl_task_type)) {
29+
return(TRUE)
30+
}
31+
32+
cl_object = fget(mlr_reflections$task_types, object$task_type, class, "type")
33+
return(cl_task_type == cl_object)
34+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
skip_if_not_installed("mlr3spatiotempcv")
2+
skip_on_cran()
3+
4+
test_that("task detection works with mlr3spatiotempcv tasks", {
5+
pkg = "mlr3spatiotempcv"
6+
library(pkg, character.only = TRUE) # FIXME: replace with requireNamespace()
7+
task = tsk("ecuador")
8+
learner = lrn("classif.rpart")
9+
10+
filter = flt("importance", learner = learner)
11+
expect_filter(filter$calculate(task))
12+
13+
filter = flt("variance")
14+
expect_filter(filter$calculate(task))
15+
16+
filter = flt("mim")
17+
expect_filter(filter$calculate(task))
18+
})

0 commit comments

Comments
 (0)