diff --git a/DESCRIPTION b/DESCRIPTION
index 40ca6dcb..d8eaab9f 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -27,7 +27,7 @@ Depends:
mlr3 (>= 0.23.0),
R (>= 3.1.0)
Imports:
- bbotk (>= 1.5.0.9000),
+ bbotk (>= 1.6.0),
checkmate (>= 2.0.0),
cli,
data.table,
@@ -47,6 +47,8 @@ Suggests:
- rush,
+ rush (>= 0.2.0),
testthat (>= 3.0.0)
+Remotes:
+ mlr-org/mlr3@aggregate_fast
Config/testthat/edition: 3
Config/testthat/parallel: false
Encoding: UTF-8
diff --git a/R/ObjectiveFSelectAsync.R b/R/ObjectiveFSelectAsync.R
index 2aca31fc..aea69cd8 100644
--- a/R/ObjectiveFSelectAsync.R
+++ b/R/ObjectiveFSelectAsync.R
@@ -16,6 +16,37 @@
#' @export
ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync",
inherit = ObjectiveFSelect,
+ public = list(
+ #' @description
+ #' Creates a new instance of this [R6][R6::R6Class] class.
+ initialize = function(
+ task,
+ learner,
+ resampling,
+ measures,
+ check_values = TRUE,
+ store_benchmark_result = TRUE,
+ store_models = FALSE,
+ callbacks = NULL
+ ) {
+ super$initialize(
+ task = task,
+ learner = learner,
+ resampling = resampling,
+ measures = measures,
+ store_benchmark_result = store_benchmark_result,
+ store_models = store_models,
+ check_values = check_values,
+ callbacks = callbacks
+ )
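+ # choose the aggregation routine: the fast path is only valid when a single
+ # performance value is optimized and no measure needs the task, learner,
+ # model, or train set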
+ measure_properties = unlist(map(self$measures, "properties"))
+ if (self$codomain$length == 1 && all(c("requires_task", "requires_learner", "requires_model", "requires_train_set") %nin% measure_properties)) {
+ private$.aggregator = async_aggregator_fast
+ } else {
+ private$.aggregator = async_aggregator_default
+ }
+ }
+ ),
private = list(
.eval = function(xs, resampling) {
lg$debug("Evaluating feature subset %s", as_short_string(xs))
@@ -37,8 +68,8 @@ ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync",
lg$debug("Aggregating performance")
- # aggregate performance
- private$.aggregated_performance = as.list(private$.resample_result$aggregate(self$measures))
+ # aggregate performance using the appropriate aggregator
+ private$.aggregated_performance = as.list(private$.aggregator(private$.resample_result, self$measures))
lg$debug("Aggregated performance %s", as_short_string(private$.aggregated_performance))
@@ -61,6 +92,15 @@ ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync",
.xs = NULL,
.resample_result = NULL,
- .aggregated_performance = NULL
+ .aggregated_performance = NULL,
+ .aggregator = NULL
)
)
+
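+# default aggregation: delegate to ResampleResult$aggregate() for all measures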
+async_aggregator_default = function(resample_result, measures) {
+ resample_result$aggregate(measures)
+}
+
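+# fast aggregation: mlr3::faggregate() (see Remotes: mlr-org/mlr3@aggregate_fast)
+# scores only the first measure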
+async_aggregator_fast = function(resample_result, measures) {
+ mlr3::faggregate(resample_result, measures[[1]])
+}
diff --git a/R/ObjectiveFSelectBatch.R b/R/ObjectiveFSelectBatch.R
index 31ae0027..b2a5acd6 100644
--- a/R/ObjectiveFSelectBatch.R
+++ b/R/ObjectiveFSelectBatch.R
@@ -51,6 +51,12 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch",
check_values = check_values,
callbacks = callbacks
)
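+ # choose the aggregator as in ObjectiveFSelectAsync: fast aggregation is only
+ # used for a single measure without requires_* properties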
+ measure_properties = unlist(map(self$measures, "properties"))
+ if (self$codomain$length == 1 && all(c("requires_task", "requires_learner", "requires_model", "requires_train_set") %nin% measure_properties)) {
+ private$.aggregator = aggregator_fast
+ } else {
+ private$.aggregator = aggregator_default
+ }
}
),
@@ -80,16 +86,10 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch",
lg$debug("Aggregating performance")
# aggregate performance scores
- private$.aggregated_performance = private$.benchmark_result$aggregate(self$measures, conditions = TRUE)[, c(self$codomain$target_ids, "warnings", "errors"), with = FALSE]
+ private$.aggregated_performance = private$.aggregator(private$.benchmark_result, self$measures, self$codomain)
lg$debug("Aggregated performance %s", as_short_string(private$.aggregated_performance))
- # add runtime to evaluations
- time = map_dbl(private$.benchmark_result$resample_results$resample_result, function(rr) {
- sum(map_dbl(get_private(rr)$.data$learner_states(get_private(rr)$.view), function(state) state$train_time + state$predict_time))
- })
- set(private$.aggregated_performance, j = "runtime_learners", value = time)
-
# store benchmark result in archive
if (self$store_benchmark_result) {
lg$debug("Storing resample result")
@@ -106,6 +106,34 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch",
.design = NULL,
.benchmark_result = NULL,
.aggregated_performance = NULL,
- .model_required = FALSE
+ .model_required = FALSE,
+ .aggregator = NULL
)
)
+
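+# default aggregation: BenchmarkResult$aggregate() with warning/error columns,
+# plus the learner runtimes summed per evaluation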
+aggregator_default = function(benchmark_result, measures, codomain) {
+ aggr = benchmark_result$aggregate(measures, conditions = TRUE)[, c(codomain$target_ids, "warnings", "errors"), with = FALSE]
+
+ # add runtime
+ data = get_private(benchmark_result)$.data$data
+ tab = data$fact[data$uhashes, c("uhash", "learner_state"), with = FALSE]
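+ # NULL assignment silences R CMD check "no visible binding" notes for the
+ # data.table NSE column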
+ learner_state = NULL
+ runtime = tab[, sum(map_dbl(learner_state, function(s) sum(s$train_time + s$predict_time))), by = uhash]$V1
+ set(aggr, j = "runtime_learners", value = runtime)
+ aggr
+}
+
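+# fast aggregation: faggregate() scores only the first measure; warnings,
+# errors, and learner runtimes are reconstructed from the learner states below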
+aggregator_fast = function(benchmark_result, measures, codomain) {
+ aggr = faggregate(benchmark_result, measures[[1]])
+
+ # add runtime and conditions
+ data = get_private(benchmark_result)$.data$data
+ tab = data$fact[data$uhashes, c("uhash", "learner_state"), with = FALSE]
+
+ learner_state = NULL
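+ # count warnings and errors in the learner logs and sum the runtimes per
+ # uhash, then join the result onto the aggregated scores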
+ aggr[tab[, list(
+ errors = sum(map_int(learner_state, function(s) sum(s$log$class == "error"))),
+ warnings = sum(map_int(learner_state, function(s) sum(s$log$class == "warning"))),
+ runtime_learners = sum(map_dbl(learner_state, function(s) sum(s$train_time + s$predict_time)))
+ ), by = uhash], on = "uhash"]
+}
diff --git a/R/helper.R b/R/helper.R
index cd9b1117..e41ffaf1 100644
--- a/R/helper.R
+++ b/R/helper.R
@@ -19,7 +19,7 @@ measures_to_codomain = function(measures) {
}
extract_runtime = function(resample_result) {
- runtimes = map_dbl(get_private(resample_result)$.data$learner_states(get_private(resample_result)$.view), function(state) {
+ runtimes = map_dbl(get_private(resample_result)$.data$learner_states(), function(state) {
state$train_time + state$predict_time
})
sum(runtimes)
diff --git a/inst/testthat/helper_misc.R b/inst/testthat/helper_misc.R
index b07b2437..68d3ca5a 100644
--- a/inst/testthat/helper_misc.R
+++ b/inst/testthat/helper_misc.R
@@ -37,7 +37,7 @@ MeasureDummy = R6Class("MeasureDummy", inherit = MeasureRegr,
)
}
private$.score_design = score_design
- super$initialize(id = "dummy", range = c(0, 4), minimize = minimize)
+ super$initialize(id = "dummy", range = c(0, 4), minimize = minimize, properties = c("requires_task", "requires_learner"))
}
),
private = list(
diff --git a/man-roxygen/param_aggregate_fast.R b/man-roxygen/param_aggregate_fast.R
new file mode 100644
index 00000000..91ba62fb
--- /dev/null
+++ b/man-roxygen/param_aggregate_fast.R
@@ -0,0 +1,5 @@
+#' @param aggregate_fast (`logical(1)`)\cr
+#' If `TRUE`, the performance values are aggregated with a faster routine.
+#' This is only supported for measures that do not require the task, learner, model, or train set.
+#' The archive then does not contain warnings and errors.
+#' Default is `FALSE`.
diff --git a/man/ObjectiveFSelectAsync.Rd b/man/ObjectiveFSelectAsync.Rd
index b9aad69c..3bbaf11b 100644
--- a/man/ObjectiveFSelectAsync.Rd
+++ b/man/ObjectiveFSelectAsync.Rd
@@ -13,6 +13,7 @@ This class is usually constructed internally by the \link{FSelectInstanceAsyncSi
\section{Methods}{
\subsection{Public methods}{
\itemize{
+\item \href{#method-ObjectiveFSelectAsync-new}{\code{ObjectiveFSelectAsync$new()}}
\item \href{#method-ObjectiveFSelectAsync-clone}{\code{ObjectiveFSelectAsync$clone()}}
}
}
@@ -25,11 +26,62 @@ This class is usually constructed internally by the \link{FSelectInstanceAsyncSi
bbotk::Objective$format()
bbotk::Objective$help()
bbotk::Objective$print()
-mlr3fselect::ObjectiveFSelect$initialize()
}}
\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ObjectiveFSelectAsync-new"></a>}}
+\if{latex}{\out{\hypertarget{method-ObjectiveFSelectAsync-new}{}}}
+\subsection{Method \code{new()}}{
+Creates a new instance of this \link[R6:R6Class]{R6} class.
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ObjectiveFSelectAsync$new(
+ task,
+ learner,
+ resampling,
+ measures,
+ check_values = TRUE,
+ store_benchmark_result = TRUE,
+ store_models = FALSE,
+ callbacks = NULL
+)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task})\cr
+Task to operate on.}
+
+\item{\code{learner}}{(\link[mlr3:Learner]{mlr3::Learner})\cr
+Learner to optimize the feature subset for.}
+
+\item{\code{resampling}}{(\link[mlr3:Resampling]{mlr3::Resampling})\cr
+Resampling that is used to evaluate the performance of the feature subsets.
+Uninstantiated resamplings are instantiated during construction so that all feature subsets are evaluated on the same data splits.
+Already instantiated resamplings are kept unchanged.}
+
+\item{\code{measures}}{(list of \link[mlr3:Measure]{mlr3::Measure})\cr
+Measures to optimize.
+If \code{NULL}, \CRANpkg{mlr3}'s default measure is used.}
+
+\item{\code{check_values}}{(\code{logical(1)})\cr
+Check the parameters before the evaluation and the results for
+validity?}
+
+\item{\code{store_benchmark_result}}{(\code{logical(1)})\cr
+Store benchmark result in archive?}
+
+\item{\code{store_models}}{(\code{logical(1)})\cr
+Store models in benchmark result?}
+
+\item{\code{callbacks}}{(list of \link{CallbackBatchFSelect})\cr
+List of callbacks.}
+}
+\if{html}{\out{</div>}}
+}
+}
+\if{html}{\out{<hr>}}
\if{html}{\out{<a id="method-ObjectiveFSelectAsync-clone"></a>}}
\if{latex}{\out{\hypertarget{method-ObjectiveFSelectAsync-clone}{}}}
\subsection{Method \code{clone()}}{
diff --git a/man/mlr_fselectors_async_random_search.Rd b/man/mlr_fselectors_async_random_search.Rd
index c67a7f7f..3fde6cfd 100644
--- a/man/mlr_fselectors_async_random_search.Rd
+++ b/man/mlr_fselectors_async_random_search.Rd
@@ -13,11 +13,6 @@ Bergstra J, Bengio Y (2012).
\description{
Feature selection using Asynchronous Random Search Algorithm.
}
-\details{
-The feature sets are randomly drawn.
-The sets are evaluated asynchronously.
-The algorithm uses \link[bbotk:mlr_optimizers_async_random_search]{bbotk::OptimizerAsyncRandomSearch} for optimization.
-}
\section{Dictionary}{
This \link{FSelector} can be instantiated with the associated sugar function \code{\link[=fs]{fs()}}:
diff --git a/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R b/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R
index 7210d6c4..84bd2205 100644
--- a/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R
+++ b/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R
@@ -172,3 +172,66 @@ test_that("saving the models with FSelectInstanceAsyncSingleCrit works", {
# fselector$optimize(instance)
# })
+
+
+test_that("fast aggregation and benchmark result produce the same scores", {
+ skip_on_cran()
+ skip_if_not_installed("rush")
+ flush_redis()
+
+ on.exit(mirai::daemons(0))
+ mirai::daemons(1)
+ rush::rush_plan(n_workers = 1, worker_type = "remote")
+
+ instance = fsi_async(
+ task = tsk("pima"),
+ learner = lrn("classif.rpart"),
+ resampling = rsmp("cv", folds = 3),
+ measures = msr("classif.ce"),
+ terminator = trm("evals", n_evals = 6)
+ )
+
+ fselector = fs("async_random_search")
+ fselector$optimize(instance)
+
+ expect_equal(get_private(instance$objective)$.aggregator, async_aggregator_fast)
+
+ expect_equal(instance$archive$data$classif.ce,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)
+})
+
+test_that("fast aggregation and benchmark result produce the same conditions", {
+ skip_on_cran()
+ skip_if_not_installed("rush")
+ flush_redis()
+
+ on.exit(mirai::daemons(0))
+ mirai::daemons(1)
+ rush::rush_plan(n_workers = 1, worker_type = "remote")
+
+
+ learner = lrn("classif.debug", error_train = 0.5, warning_train = 0.5)
+ learner$encapsulate("callr", fallback = lrn("classif.debug"))
+
+ instance = fsi_async(
+ task = tsk("pima"),
+ learner = learner,
+ resampling = rsmp("cv", folds = 3),
+ measures = msr("classif.ce"),
+ terminator = trm("evals", n_evals = 6)
+ )
+
+ fselector = fs("async_random_search")
+ fselector$optimize(instance)
+
+ expect_equal(get_private(instance$objective)$.aggregator, async_aggregator_fast)
+
+ expect_equal(instance$archive$data$classif.ce,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)
+
+ expect_equal(instance$archive$data$errors,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$errors)
+
+ expect_equal(instance$archive$data$warnings,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$warnings)
+})
diff --git a/tests/testthat/test_FSelectInstanceSingleCrit.R b/tests/testthat/test_FSelectInstanceSingleCrit.R
index 385c2ecd..9cfc6748 100644
--- a/tests/testthat/test_FSelectInstanceSingleCrit.R
+++ b/tests/testthat/test_FSelectInstanceSingleCrit.R
@@ -132,3 +132,48 @@ test_that("objective contains no benchmark results", {
expect_null(instance$objective$.__enclos_env__$private$.benchmark_result)
})
+
+test_that("fast aggregation and benchmark result produce the same scores", {
+ instance = fsi(
+ task = tsk("pima"),
+ learner = lrn("classif.rpart"),
+ resampling = rsmp("cv", folds = 3),
+ measures = msr("classif.ce"),
+ terminator = trm("evals", n_evals = 6)
+ )
+
+ fselector = fs("random_search", batch_size = 2)
+ fselector$optimize(instance)
+
+ expect_equal(get_private(instance$objective)$.aggregator, aggregator_fast)
+
+ expect_equal(instance$archive$data$classif.ce,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)
+})
+
+test_that("fast aggregation and benchmark result produce the same conditions", {
+ learner = lrn("classif.debug", error_train = 0.5, warning_train = 0.5)
+ learner$encapsulate("callr", fallback = lrn("classif.debug"))
+
+ instance = fsi(
+ task = tsk("pima"),
+ learner = learner,
+ resampling = rsmp("cv", folds = 3),
+ measures = msr("classif.ce"),
+ terminator = trm("evals", n_evals = 6)
+ )
+
+ fselector = fs("random_search", batch_size = 2)
+ fselector$optimize(instance)
+
+ expect_equal(get_private(instance$objective)$.aggregator, aggregator_fast)
+
+ expect_equal(instance$archive$data$classif.ce,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)
+
+ expect_equal(instance$archive$data$errors,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$errors)
+
+ expect_equal(instance$archive$data$warnings,
+ instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$warnings)
+})
diff --git a/tests/testthat/test_fselect.R b/tests/testthat/test_fselect.R
index 7cf0fad7..98f422ab 100644
--- a/tests/testthat/test_fselect.R
+++ b/tests/testthat/test_fselect.R
@@ -38,7 +38,7 @@ test_that("fselect interface is equal to FSelectInstanceBatchSingleCrit", {
test_that("fselect interface is equal to FSelectInstanceBatchMultiCrit", {
fselect_args = formalArgs(fselect)
- fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method")]
+ fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method", "aggregate_fast")]
instance_args = formalArgs(FSelectInstanceBatchMultiCrit$public_methods$initialize)
instance_args = c(instance_args, "term_evals", "term_time", "rush")
diff --git a/tests/testthat/test_fsi.R b/tests/testthat/test_fsi.R
index e327dfe8..29ad5caf 100644
--- a/tests/testthat/test_fsi.R
+++ b/tests/testthat/test_fsi.R
@@ -43,7 +43,7 @@ test_that("fsi and FSelectInstanceBatchSingleCrit are equal", {
test_that("fsi and FSelectInstanceBatchMultiCrit are equal", {
fsi_args = formalArgs(fsi)
- fsi_args = fsi_args[fsi_args != "ties_method"]
+ fsi_args = fsi_args[fsi_args %nin% c("ties_method", "aggregate_fast")]
expect_equal(fsi_args, formalArgs(FSelectInstanceBatchMultiCrit$public_methods$initialize))
diff --git a/tests/testthat/test_fsi_async.R b/tests/testthat/test_fsi_async.R
index fca2c7aa..ffc0f696 100644
--- a/tests/testthat/test_fsi_async.R
+++ b/tests/testthat/test_fsi_async.R
@@ -44,7 +44,7 @@ test_that("fsi_async interface is equal to FSelectInstanceAsyncMultiCrit", {
flush_redis()
fsi_args = formalArgs(fsi_async)
- fsi_args = fsi_args[fsi_args != "ties_method"]
+ fsi_args = fsi_args[fsi_args %nin% c("ties_method", "aggregate_fast")]
instance_args = formalArgs(FSelectInstanceAsyncMultiCrit$public_methods$initialize)
expect_equal(fsi_args, instance_args)