Skip to content

feat: add fast aggregation #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Depends:
mlr3 (>= 0.23.0),
R (>= 3.1.0)
Imports:
bbotk (>= 1.5.0.9000),
bbotk (>= 1.6.0),
checkmate (>= 2.0.0),
cli,
data.table,
Expand All @@ -47,6 +47,8 @@ Suggests:
rush,
rush (>= 0.2.0),
testthat (>= 3.0.0)
Remotes:
mlr-org/mlr3@aggregate_fast
Config/testthat/edition: 3
Config/testthat/parallel: false
Encoding: UTF-8
Expand Down
46 changes: 43 additions & 3 deletions R/ObjectiveFSelectAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,37 @@
#' @export
ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync",
inherit = ObjectiveFSelect,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' In addition to the superclass setup, the constructor chooses the
#' aggregation strategy once, at construction time: the fast path is
#' picked only when the codomain has a single target and the measure
#' needs neither task, learner, model, nor train set.
initialize = function(
  task,
  learner,
  resampling,
  measures,
  check_values = TRUE,
  store_benchmark_result = TRUE,
  store_models = FALSE,
  callbacks = NULL
) {
  super$initialize(
    task = task,
    learner = learner,
    resampling = resampling,
    measures = measures,
    store_benchmark_result = store_benchmark_result,
    store_models = store_models,
    check_values = check_values,
    callbacks = callbacks
  )
  # pool the properties of all measures; any "requires_*" property
  # forces the default aggregator because faggregate() cannot supply
  # task/learner/model/train set to the measure
  # NOTE(review): same selection logic exists in ObjectiveFSelectBatch --
  # consider extracting a shared helper
  measure_properties = unlist(map(self$measures, "properties"))
  if (self$codomain$length == 1 && all(c("requires_task", "requires_learner", "requires_model", "requires_train_set") %nin% measure_properties)) {
    # single cheap measure -> fast aggregation via mlr3::faggregate()
    private$.aggregator = async_aggregator_fast
  } else {
    # fall back to the regular ResampleResult$aggregate() path
    private$.aggregator = async_aggregator_default
  }
}
),
private = list(
.eval = function(xs, resampling) {
lg$debug("Evaluating feature subset %s", as_short_string(xs))
Expand All @@ -37,8 +68,8 @@ ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync",

lg$debug("Aggregating performance")

# aggregate performance
private$.aggregated_performance = as.list(private$.resample_result$aggregate(self$measures))
# aggregate performance using the appropriate aggregator
private$.aggregated_performance = as.list(private$.aggregator(private$.resample_result, self$measures))

lg$debug("Aggregated performance %s", as_short_string(private$.aggregated_performance))

Expand All @@ -61,6 +92,15 @@ ObjectiveFSelectAsync = R6Class("ObjectiveFSelectAsync",

.xs = NULL,
.resample_result = NULL,
.aggregated_performance = NULL
.aggregated_performance = NULL,
.aggregator = NULL
)
)

# Default async aggregator.
# Delegates to ResampleResult$aggregate(), which scores every measure
# and can therefore handle measures with "requires_*" properties.
async_aggregator_default = function(resample_result, measures) {
  aggregated = resample_result$aggregate(measures)
  aggregated
}

# Fast async aggregator.
# Shortcut for the single-measure case: scores only the first measure
# with mlr3::faggregate(). Callers guarantee length(measures) == 1.
async_aggregator_fast = function(resample_result, measures) {
  measure = measures[[1]]
  mlr3::faggregate(resample_result, measure)
}
44 changes: 36 additions & 8 deletions R/ObjectiveFSelectBatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch",
check_values = check_values,
callbacks = callbacks
)
measure_properties = unlist(map(self$measures, "properties"))
if (self$codomain$length == 1 && all(c("requires_task", "requires_learner", "requires_model", "requires_train_set") %nin% measure_properties)) {
private$.aggregator = aggregator_fast
} else {
private$.aggregator = aggregator_default
}
}
),

Expand Down Expand Up @@ -80,16 +86,10 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch",
lg$debug("Aggregating performance")

# aggregate performance scores
private$.aggregated_performance = private$.benchmark_result$aggregate(self$measures, conditions = TRUE)[, c(self$codomain$target_ids, "warnings", "errors"), with = FALSE]
private$.aggregated_performance = private$.aggregator(private$.benchmark_result, self$measures, self$codomain)

lg$debug("Aggregated performance %s", as_short_string(private$.aggregated_performance))

# add runtime to evaluations
time = map_dbl(private$.benchmark_result$resample_results$resample_result, function(rr) {
sum(map_dbl(get_private(rr)$.data$learner_states(get_private(rr)$.view), function(state) state$train_time + state$predict_time))
})
set(private$.aggregated_performance, j = "runtime_learners", value = time)

# store benchmark result in archive
if (self$store_benchmark_result) {
lg$debug("Storing resample result")
Expand All @@ -106,6 +106,34 @@ ObjectiveFSelectBatch = R6Class("ObjectiveFSelectBatch",
.design = NULL,
.benchmark_result = NULL,
.aggregated_performance = NULL,
.model_required = FALSE
.model_required = FALSE,
.aggregator = NULL
)
)

# Default batch aggregator.
# Aggregates all measures via BenchmarkResult$aggregate() (including
# warning/error counts) and appends the summed learner runtime per
# resample result (keyed by uhash).
aggregator_default = function(benchmark_result, measures, codomain) {
  # keep only the target columns plus the condition counters
  aggr = benchmark_result$aggregate(measures, conditions = TRUE)[, c(codomain$target_ids, "warnings", "errors"), with = FALSE]

  # add runtime
  # reach into the private benchmark data to sum train + predict time
  # over every learner state, grouped by uhash
  data = get_private(benchmark_result)$.data$data
  tab = data$fact[data$uhashes, c("uhash", "learner_state"), with = FALSE]
  learner_state = NULL # declare NSE variable to appease R CMD check
  runtime = tab[, sum(map_dbl(learner_state, function(s) sum(s$train_time + s$predict_time))), by = uhash]$V1
  # NOTE(review): assumes the by-uhash group order matches the row order
  # of `aggr` -- presumably both follow data$uhashes; confirm upstream
  set(aggr, j = "runtime_learners", value = runtime)
  aggr
}

# Fast batch aggregator.
# Scores the single measure with mlr3::faggregate() and joins in the
# error/warning counts and summed learner runtime per uhash. Callers
# guarantee length(measures) == 1 (enforced by the codomain check in
# the objective's initialize()).
aggregator_fast = function(benchmark_result, measures, codomain) {
  # qualify with mlr3:: for consistency with async_aggregator_fast and
  # so the call does not depend on faggregate being imported
  aggr = mlr3::faggregate(benchmark_result, measures[[1]])

  # add runtime and conditions
  # pull the raw learner states out of the private benchmark data
  data = get_private(benchmark_result)$.data$data
  tab = data$fact[data$uhashes, c("uhash", "learner_state"), with = FALSE]

  learner_state = NULL # declare NSE variable to appease R CMD check
  # per uhash: count logged errors/warnings and sum train + predict
  # time, then join onto the aggregated scores by uhash
  aggr[tab[, list(
    errors = sum(map_int(learner_state, function(s) sum(s$log$class == "error"))),
    warnings = sum(map_int(learner_state, function(s) sum(s$log$class == "warning"))),
    runtime_learners = sum(map_dbl(learner_state, function(s) sum(s$train_time + s$predict_time)))
  ), by = uhash], on = "uhash"]
}
2 changes: 1 addition & 1 deletion R/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ measures_to_codomain = function(measures) {
}

# Total learner runtime of a resample result: the sum of train and
# predict time over all learner states.
extract_runtime = function(resample_result) {
  states = get_private(resample_result)$.data$learner_states()
  sum(vapply(states, function(s) s$train_time + s$predict_time, numeric(1)))
}
Expand Down
2 changes: 1 addition & 1 deletion inst/testthat/helper_misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ MeasureDummy = R6Class("MeasureDummy", inherit = MeasureRegr,
)
}
private$.score_design = score_design
super$initialize(id = "dummy", range = c(0, 4), minimize = minimize)
super$initialize(id = "dummy", range = c(0, 4), minimize = minimize, properties = c("requires_task", "requires_learner"))
}
),
private = list(
Expand Down
5 changes: 5 additions & 0 deletions man-roxygen/param_aggregate_fast.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' @param aggregate_fast (`logical(1)`)\cr
#'   If `TRUE`, performance values are aggregated using a fast path.
#'   This is only supported for measures that do not require the task, learner, model, or train set.
#'   The archive then does not contain warnings and errors.
#'   Default is `FALSE`.
54 changes: 53 additions & 1 deletion man/ObjectiveFSelectAsync.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 0 additions & 5 deletions man/mlr_fselectors_async_random_search.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 63 additions & 0 deletions tests/testthat/test_FSelectInstanceAsyncSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,66 @@ test_that("saving the models with FSelectInstanceAsyncSingleCrit works", {

# fselector$optimize(instance)
# })


# A single classif.ce measure has no "requires_*" properties, so the
# objective should pick async_aggregator_fast; the archived scores must
# nevertheless equal the regular BenchmarkResult$aggregate() results.
test_that("fast aggregation and benchmark result produce the same scores", {
  skip_on_cran()
  skip_if_not_installed("rush")
  flush_redis()

  # one local rush worker backed by a mirai daemon
  on.exit(mirai::daemons(0))
  mirai::daemons(1)
  rush::rush_plan(n_workers = 1, worker_type = "remote")

  instance = fsi_async(
    task = tsk("pima"),
    learner = lrn("classif.rpart"),
    resampling = rsmp("cv", folds = 3),
    measures = msr("classif.ce"),
    terminator = trm("evals", n_evals = 6)
  )

  fselector = fs("async_random_search")
  fselector$optimize(instance)

  # the fast aggregator must have been selected at construction time
  expect_equal(get_private(instance$objective)$.aggregator, async_aggregator_fast)

  # fast-path scores agree with the default aggregation
  expect_equal(instance$archive$data$classif.ce,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)
})

# Same as the scores test, but with a learner that randomly errors and
# warns during training: the fast path must also report the same
# warning and error counts as the default aggregation.
test_that("fast aggregation and benchmark result produce the same conditions", {
  skip_on_cran()
  skip_if_not_installed("rush")
  flush_redis()

  on.exit(mirai::daemons(0))
  mirai::daemons(1)
  rush::rush_plan(n_workers = 1, worker_type = "remote")


  # 50% chance of an error/warning per training; encapsulation with a
  # fallback keeps the optimization running despite failures
  learner = lrn("classif.debug", error_train = 0.5, warning_train = 0.5)
  learner$encapsulate("callr", fallback = lrn("classif.debug"))

  instance = fsi_async(
    task = tsk("pima"),
    learner = learner,
    resampling = rsmp("cv", folds = 3),
    measures = msr("classif.ce"),
    terminator = trm("evals", n_evals = 6)
  )

  fselector = fs("async_random_search")
  fselector$optimize(instance)

  # classif.ce has no "requires_*" properties -> fast aggregator chosen
  expect_equal(get_private(instance$objective)$.aggregator, async_aggregator_fast)

  expect_equal(instance$archive$data$classif.ce,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)

  expect_equal(instance$archive$data$errors,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$errors)

  expect_equal(instance$archive$data$warnings,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$warnings)
})
45 changes: 45 additions & 0 deletions tests/testthat/test_FSelectInstanceSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,48 @@ test_that("objective contains no benchmark results", {

expect_null(instance$objective$.__enclos_env__$private$.benchmark_result)
})

# Batch counterpart of the async score test: classif.ce qualifies for
# the fast path, and the archived scores must match the default
# BenchmarkResult$aggregate() results.
test_that("fast aggregation and benchmark result produce the same scores", {
  instance = fsi(
    task = tsk("pima"),
    learner = lrn("classif.rpart"),
    resampling = rsmp("cv", folds = 3),
    measures = msr("classif.ce"),
    terminator = trm("evals", n_evals = 6)
  )

  fselector = fs("random_search", batch_size = 2)
  fselector$optimize(instance)

  # the batch objective must have selected aggregator_fast
  expect_equal(get_private(instance$objective)$.aggregator, aggregator_fast)

  expect_equal(instance$archive$data$classif.ce,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)
})

# Batch counterpart of the async conditions test: with a randomly
# failing learner, the fast aggregator must report the same warning and
# error counts as the default aggregation.
test_that("fast aggregation and benchmark result produce the same conditions", {
  # 50% chance of an error/warning per training; encapsulation with a
  # fallback keeps the optimization running despite failures
  learner = lrn("classif.debug", error_train = 0.5, warning_train = 0.5)
  learner$encapsulate("callr", fallback = lrn("classif.debug"))

  instance = fsi(
    task = tsk("pima"),
    learner = learner,
    resampling = rsmp("cv", folds = 3),
    measures = msr("classif.ce"),
    terminator = trm("evals", n_evals = 6)
  )

  fselector = fs("random_search", batch_size = 2)
  fselector$optimize(instance)

  # classif.ce has no "requires_*" properties -> fast aggregator chosen
  expect_equal(get_private(instance$objective)$.aggregator, aggregator_fast)

  expect_equal(instance$archive$data$classif.ce,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"))$classif.ce)

  expect_equal(instance$archive$data$errors,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$errors)

  expect_equal(instance$archive$data$warnings,
    instance$archive$benchmark_result$aggregate(msr("classif.ce"), conditions = TRUE)$warnings)
})
2 changes: 1 addition & 1 deletion tests/testthat/test_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ test_that("fselect interface is equal to FSelectInstanceBatchSingleCrit", {

test_that("fselect interface is equal to FSelectInstanceBatchMultiCrit", {
fselect_args = formalArgs(fselect)
fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method")]
fselect_args = fselect_args[fselect_args %nin% c("fselector", "ties_method", "aggregate_fast")]

instance_args = formalArgs(FSelectInstanceBatchMultiCrit$public_methods$initialize)
instance_args = c(instance_args, "term_evals", "term_time", "rush")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_fsi.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ test_that("fsi and FSelectInstanceBatchSingleCrit are equal", {

test_that("fsi and FSelectInstanceBatchMultiCrit are equal", {
fsi_args = formalArgs(fsi)
fsi_args = fsi_args[fsi_args != "ties_method"]
fsi_args = fsi_args[fsi_args %nin% c("ties_method", "aggregate_fast")]

expect_equal(fsi_args, formalArgs(FSelectInstanceBatchMultiCrit$public_methods$initialize))

Expand Down
Loading
Loading