diff --git a/DESCRIPTION b/DESCRIPTION index 40ca6dcb..d8eaab9f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,7 +27,7 @@ Depends: mlr3 (>= 0.23.0), R (>= 3.1.0) Imports: - bbotk (>= 1.5.0.9000), + bbotk (>= 1.6.0), checkmate (>= 2.0.0), cli, data.table, @@ -47,6 +47,8 @@ Suggests: rush, rush (>= 0.2.0), testthat (>= 3.0.0) +Remotes: + mlr-org/mlr3@aggregate_fast Config/testthat/edition: 3 Config/testthat/parallel: false Encoding: UTF-8 diff --git a/R/ObjectiveFSelectAsync.R b/R/ObjectiveFSelectAsync.R index 2aca31fc..aea69cd8 100644 --- a/R/ObjectiveFSelectAsync.R +++ b/R/ObjectiveFSelectAsync.R @@ -16,6 +16,37 @@ #' @export ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync", inherit = ObjectiveFSelect, + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + initialize = function( + task, + learner, + resampling, + measures, + check_values = TRUE, + store_benchmark_result = TRUE, + store_models = FALSE, + callbacks = NULL + ) { + super$initialize( + task = task, + learner = learner, + resampling = resampling, + measures = measures, + store_benchmark_result = store_benchmark_result, + store_models = store_models, + check_values = check_values, + callbacks = callbacks + ) + measure_properties = unlist(map(self$measures, "properties")) + if (self$codomain$length == 1 && all(c("requires_task", "requires_learner", "requires_model", "requires_train_set") %nin% measure_properties)) { + private$.aggregator = async_aggregator_fast + } else { + private$.aggregator = async_aggregator_default + } + } + ), private = list( .eval = function(xs, resampling) { lg$debug("Evaluating feature subset %s", as_short_string(xs)) @@ -37,8 +68,8 @@ ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync", lg$debug("Aggregating performance") - # aggregate performance - private$.aggregated_performance = as.list(private$.resample_result$aggregate(self$measures)) + # aggregate performance using the appropriate aggregator + private$.aggregated_performance = as.list(private$.aggregator(private$.resample_result, self$measures)) lg$debug("Aggregated performance %s", as_short_string(private$.aggregated_performance)) @@ -61,6 +92,15 @@ ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync", .xs = NULL, .resample_result = NULL, - .aggregated_performance = NULL + .aggregated_performance = NULL, + .aggregator = NULL ) ) + +async_aggregator_default = function(resample_result, measures) { + resample_result$aggregate(measures) +} + +async_aggregator_fast = function(resample_result, measures) { + mlr3::faggregate(resample_result, measures[[1]]) +} diff --git a/R/ObjectiveFSelectBatch.R b/R/ObjectiveFSelectBatch.R index 31ae0027..b2a5acd6 100644 --- a/R/ObjectiveFSelectBatch.R +++ b/R/ObjectiveFSelectBatch.R @@ -51,6 +51,12 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch", check_values = check_values, callbacks = callbacks ) + measure_properties = unlist(map(self$measures, "properties")) + if (self$codomain$length == 1 && all(c("requires_task", "requires_learner", "requires_model", "requires_train_set") %nin% measure_properties)) { + private$.aggregator = aggregator_fast + } else { + private$.aggregator = aggregator_default + } } ), @@ -80,16 +86,10 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch", lg$debug("Aggregating performance") # aggregate performance scores - private$.aggregated_performance = private$.benchmark_result$aggregate(self$measures, conditions = TRUE)[, c(self$codomain$target_ids, "warnings", "errors"), with = FALSE] + private$.aggregated_performance = private$.aggregator(private$.benchmark_result, self$measures, self$codomain) lg$debug("Aggregated performance %s", as_short_string(private$.aggregated_performance)) - # add runtime to evaluations - time = map_dbl(private$.benchmark_result$resample_results$resample_result, function(rr) { - sum(map_dbl(get_private(rr)$.data$learner_states(get_private(rr)$.view), function(state) state$train_time + state$predict_time)) - }) - set(private$.aggregated_performance, j = "runtime_learners", value = time) - # store benchmark result in archive if (self$store_benchmark_result) { lg$debug("Storing resample result") @@ -106,6 +106,34 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch", .design = NULL, .benchmark_result = NULL, .aggregated_performance = NULL, - .model_required = FALSE + .model_required = FALSE, + .aggregator = NULL ) ) + +aggregator_default = function(benchmark_result, measures, codomain) { + aggr = benchmark_result$aggregate(measures, conditions = TRUE)[, c(codomain$target_ids, "warnings", "errors"), with = FALSE] + + # add runtime + data = get_private(benchmark_result)$.data$data + tab = data$fact[data$uhashes, c("uhash", "learner_state"), with = FALSE] + learner_state = NULL + runtime = tab[, sum(map_dbl(learner_state, function(s) sum(s$train_time + s$predict_time))), by = uhash]$V1 + set(aggr, j = "runtime_learners", value = runtime) + aggr +} + +aggregator_fast = function(benchmark_result, measures, codomain) { + aggr = faggregate(benchmark_result, measures[[1]]) + + # add runtime and conditions + data = get_private(benchmark_result)$.data$data + tab = data$fact[data$uhashes, c("uhash", "learner_state"), with = FALSE] + + learner_state = NULL + aggr[tab[, list( + errors = sum(map_int(learner_state, function(s) sum(s$log$class == "error"))), + warnings = sum(map_int(learner_state, function(s) sum(s$log$class == "warning"))), + runtime_learners = sum(map_dbl(learner_state, function(s) sum(s$train_time + s$predict_time))) + ), by = uhash], on = "uhash"] +} diff --git a/R/helper.R b/R/helper.R index cd9b1117..e41ffaf1 100644 --- a/R/helper.R +++ b/R/helper.R @@ -19,7 +19,7 @@ measures_to_codomain = function(measures) { } extract_runtime = function(resample_result) { - runtimes = map_dbl(get_private(resample_result)$.data$learner_states(get_private(resample_result)$.view), function(state) { + runtimes = map_dbl(get_private(resample_result)$.data$learner_states(), function(state) { state$train_time + state$predict_time }) sum(runtimes) diff --git a/inst/testthat/helper_misc.R b/inst/testthat/helper_misc.R index b07b2437..68d3ca5a 100644 --- a/inst/testthat/helper_misc.R +++ b/inst/testthat/helper_misc.R @@ -37,7 +37,7 @@ MeasureDummy = R6Class("MeasureDummy", inherit = MeasureRegr, ) } private$.score_design = score_design - super$initialize(id = "dummy", range = c(0, 4), minimize = minimize) + super$initialize(id = "dummy", range = c(0, 4), minimize = minimize, properties = c("requires_task", "requires_learner")) } ), private = list( diff --git a/man-roxygen/param_aggregate_fast.R b/man-roxygen/param_aggregate_fast.R new file mode 100644 index 00000000..91ba62fb --- /dev/null +++ b/man-roxygen/param_aggregate_fast.R @@ -0,0 +1,5 @@ +#' @param aggregate_fast (`logical(1)`)\cr +#' If `TRUE`, the performance values are aggregated in a fast way. +#' This is only supported for measures that do not require task, learner, model or train set. +#' The archive does not contain warnings and errors. +#' Default is `FALSE`. diff --git a/man/ObjectiveFSelectAsync.Rd b/man/ObjectiveFSelectAsync.Rd index b9aad69c..3bbaf11b 100644 --- a/man/ObjectiveFSelectAsync.Rd +++ b/man/ObjectiveFSelectAsync.Rd @@ -13,6 +13,7 @@ This class is usually constructed internally by the \link{FSelectInstanceAsyncSi \section{Methods}{ \subsection{Public methods}{ \itemize{ +\item \href{#method-ObjectiveFSelectAsync-new}{\code{ObjectiveFSelectAsync$new()}} \item \href{#method-ObjectiveFSelectAsync-clone}{\code{ObjectiveFSelectAsync$clone()}} } } @@ -25,11 +26,62 @@ This class is usually constructed internally by the \link{FSelectInstanceAsyncSi
  • bbotk::Objective$format()
  • bbotk::Objective$help()
  • bbotk::Objective$print()
  • -
  • mlr3fselect::ObjectiveFSelect$initialize()
  • }} \if{html}{\out{
    }} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ObjectiveFSelectAsync-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
    }}\preformatted{ObjectiveFSelectAsync$new( + task, + learner, + resampling, + measures, + check_values = TRUE, + store_benchmark_result = TRUE, + store_models = FALSE, + callbacks = NULL +)}\if{html}{\out{
    }} +} + +\subsection{Arguments}{ +\if{html}{\out{
    }} +\describe{ +\item{\code{task}}{(\link[mlr3:Task]{mlr3::Task})\cr +Task to operate on.} + +\item{\code{learner}}{(\link[mlr3:Learner]{mlr3::Learner})\cr +Learner to optimize the feature subset for.} + +\item{\code{resampling}}{(\link[mlr3:Resampling]{mlr3::Resampling})\cr +Resampling that is used to evaluated the performance of the feature subsets. +Uninstantiated resamplings are instantiated during construction so that all feature subsets are evaluated on the same data splits. +Already instantiated resamplings are kept unchanged.} + +\item{\code{measures}}{(list of \link[mlr3:Measure]{mlr3::Measure})\cr +Measures to optimize. +If \code{NULL}, \CRANpkg{mlr3}'s default measure is used.} + +\item{\code{check_values}}{(\code{logical(1)})\cr +Check the parameters before the evaluation and the results for +validity?} + +\item{\code{store_benchmark_result}}{(\code{logical(1)})\cr +Store benchmark result in archive?} + +\item{\code{store_models}}{(\code{logical(1)}). +Store models in benchmark result?} + +\item{\code{callbacks}}{(list of \link{CallbackBatchFSelect})\cr +List of callbacks.} +} +\if{html}{\out{
    }} +} +} +\if{html}{\out{
    }} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ObjectiveFSelectAsync-clone}{}}} \subsection{Method \code{clone()}}{ diff --git a/man/mlr_fselectors_async_random_search.Rd b/man/mlr_fselectors_async_random_search.Rd index c67a7f7f..3fde6cfd 100644 --- a/man/mlr_fselectors_async_random_search.Rd +++ b/man/mlr_fselectors_async_random_search.Rd @@ -13,11 +13,6 @@ Bergstra J, Bengio Y (2012). \description{ Feature selection using Asynchronous Random Search Algorithm. } -\details{ -The feature sets are randomly drawn. -The sets are evaluated asynchronously. -The algorithm uses \link[bbotk:mlr_optimizers_async_random_search]{bbotk::OptimizerAsyncRandomSearch} for optimization. -} \section{Dictionary}{ This \link{FSelector} can be instantiated with the associated sugar function \code{\link[=fs]{fs()}}: diff --git a/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R b/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R index 7210d6c4..84bd2205 100644 --- a/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R +++ b/tests/testthat/test_FSelectInstanceAsyncSingleCrit.R @@ -172,3 +172,66 @@ test_that("saving the models with FSelectInstanceAsyncSingleCrit works", { # fselector$optimize(instance) # }) + + +test_that("fast aggregation and benchmark result produce the same scores", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + on.exit(mirai::daemons(0)) + mirai::daemons(1) + rush::rush_plan(n_workers = 1, worker_type = "remote") + + instance = fsi_async( + task = tsk("pima"), + learner = lrn("classif.rpart"), + resampling = rsmp("cv", folds = 3), + measures = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + + fselector = fs("async_random_search") + fselector$optimize(instance) + + expect_equal(get_private(instance$objective)$.aggregator, async_aggregator_fast) + + expect_equal(instance$archive$data$classif.ce, + instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce) +}) + +test_that("fast aggregation and benchmark result produce the same conditions", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + on.exit(mirai::daemons(0)) + mirai::daemons(1) + rush::rush_plan(n_workers = 1, worker_type = "remote") + + + learner = lrn("classif.debug", error_train = 0.5, warning_train = 0.5) + learner$encapsulate("callr", fallback = lrn("classif.debug")) + + instance = fsi_async( + task = tsk("pima"), + learner = learner, + resampling = rsmp("cv", folds = 3), + measures = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + + fselector = fs("async_random_search") + fselector$optimize(instance) + + expect_equal(get_private(instance$objective)$.aggregator, async_aggregator_fast) + + expect_equal(instance$archive$data$classif.ce, + instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce) + + expect_equal(instance$archive$data$errors, + instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$errors) + + expect_equal(instance$archive$data$warnings, + instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$warnings) +}) diff --git a/tests/testthat/test_FSelectInstanceSingleCrit.R b/tests/testthat/test_FSelectInstanceSingleCrit.R index 385c2ecd..9cfc6748 100644 --- a/tests/testthat/test_FSelectInstanceSingleCrit.R +++ b/tests/testthat/test_FSelectInstanceSingleCrit.R @@ -132,3 +132,48 @@ test_that("objective contains no benchmark results", { expect_null(instance$objective$.__enclos_env__$private$.benchmark_result) }) + +test_that("fast aggregation and benchmark result produce the same scores", { + instance = fsi( + task = tsk("pima"), + learner = lrn("classif.rpart"), + resampling = rsmp("cv", folds = 3), + measures = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + + fselector = fs("random_search", batch_size = 2) + fselector$optimize(instance) + + expect_equal(get_private(instance$objective)$.aggregator, aggregator_fast) + + expect_equal(instance$archive$data$classif.ce, + instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce) +}) + +test_that("fast aggregation and benchmark result produce the same conditions", { + learner = lrn("classif.debug", error_train = 0.5, warning_train = 0.5) + learner$encapsulate("callr", fallback = lrn("classif.debug")) + + instance = fsi( + task = tsk("pima"), + learner = learner, + resampling = rsmp("cv", folds = 3), + measures = msr("classif.ce"), + terminator = trm("evals", n_evals = 6) + ) + + fselector = fs("random_search", batch_size = 2) + fselector$optimize(instance) + + expect_equal(get_private(instance$objective)$.aggregator, aggregator_fast) + + expect_equal(instance$archive$data$classif.ce, + instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce) + + expect_equal(instance$archive$data$errors, + instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$errors) + + expect_equal(instance$archive$data$warnings, + instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$warnings) +}) diff --git a/tests/testthat/test_fselect.R b/tests/testthat/test_fselect.R index 7cf0fad7..98f422ab 100644 --- a/tests/testthat/test_fselect.R +++ b/tests/testthat/test_fselect.R @@ -38,7 +38,7 @@ test_that("fselect interface is equal to FSelectInstanceBatchSingleCrit", { test_that("fselect interface is equal to FSelectInstanceBatchMultiCrit", { fselect_args = formalArgs(fselect) - fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method")] + fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method", "aggregate_fast")] instance_args = formalArgs(FSelectInstanceBatchMultiCrit$public_methods$initialize) instance_args = c(instance_args, "term_evals", "term_time", "rush") diff --git a/tests/testthat/test_fsi.R b/tests/testthat/test_fsi.R index e327dfe8..29ad5caf 100644 --- a/tests/testthat/test_fsi.R +++ b/tests/testthat/test_fsi.R @@ -43,7 +43,7 @@ test_that("fsi and FSelectInstanceBatchSingleCrit are equal", { test_that("fsi and FSelectInstanceBatchMultiCrit are equal", { fsi_args = formalArgs(fsi) - fsi_args = fsi_args[fsi_args != "ties_method"] + fsi_args = fsi_args[fsi_args %nin% c("ties_method", "aggregate_fast")] expect_equal(fsi_args, formalArgs(FSelectInstanceBatchMultiCrit$public_methods$initialize)) diff --git a/tests/testthat/test_fsi_async.R b/tests/testthat/test_fsi_async.R index fca2c7aa..ffc0f696 100644 --- a/tests/testthat/test_fsi_async.R +++ b/tests/testthat/test_fsi_async.R @@ -44,7 +44,7 @@ test_that("fsi_async interface is equal to FSelectInstanceAsyncMultiCrit", { flush_redis() fsi_args = formalArgs(fsi_async) - fsi_args = fsi_args[fsi_args != "ties_method"] + fsi_args = fsi_args[fsi_args %nin% c("ties_method", "aggregate_fast")] instance_args = formalArgs(FSelectInstanceAsyncMultiCrit$public_methods$initialize) expect_equal(fsi_args, instance_args)