diff --git a/R/assertions.R b/R/assertions.R index 10b195ed9..4b10a4c9e 100644 --- a/R/assertions.R +++ b/R/assertions.R @@ -207,7 +207,12 @@ assert_predictable = function(task, learner) { predict_type = fget_keys(task$col_info, i = ids, j = "type", key = "id") predict_levels = fget_keys(task$col_info, i = ids, j = "levels", key = "id") - ok = all(train_type == predict_type) && all(pmap_lgl(list(x = train_levels, y = predict_levels), identical)) + + ok = all(train_type == predict_type) + + if ("new_levels" %nin% learner$properties) { + ok = ok && all(pmap_lgl(list(x = train_levels, y = predict_levels), identical)) + } if (!ok) { stopf("Learner '%s' received task with different column info (feature type or factor level ordering) during train and predict.", learner$id) diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 2cd9553aa..bb6c3fcd5 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -1015,3 +1015,16 @@ test_that("config error does not trigger callback", { l$encapsulate("evaluate", lrn("classif.featureless")) expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner") }) + +test_that("new_levels property is working", { + learner = lrn("classif.featureless") + task = tsk("penguins") + learner$train(task) + data = task$data() + set(data, i = 1L, j = "island", value = "NewIsland") + + expect_error(learner$predict_newdata(data), "received task with different column info") + + learner$properties = c(learner$properties, "new_levels") + expect_prediction(learner$predict_newdata(data)) +})