|
| 1 | +#' @title Univariate Cox Survival Filter |
| 2 | +#' |
| 3 | +#' @name mlr_filters_univariatecox |
| 4 | +#' |
| 5 | +#' @description Calculates scores for assessing the relationship between |
| 6 | +#' individual features and the time-to-event outcome (right-censored survival |
| 7 | +#' data) using a univariate Cox proportional hazards model. |
| 8 | +#' The goal is to determine which features have a statistically significant |
| 9 | +#' association with the event of interest, typically in the context of clinical |
| 10 | +#' or biomedical research. |
| 11 | +#' |
| 12 | +#' This filter fits a [CoxPH][mlr3proba::LearnerSurvCoxPH()] learner using each |
| 13 | +#' feature independently and extracts the \eqn{p}-value that quantifies the |
| 14 | +#' significance of the feature's impact on survival. The filter value is |
| 15 | +#' `-log10(p)` where `p` is the \eqn{p}-value. This transformation is necessary |
| 16 | +#' to ensure numerical stability for very small \eqn{p}-values. Also higher |
| 17 | +#' values denote more important features. |
| 18 | +#' |
| 19 | +#' @family Filter |
| 20 | +#' @include Filter.R |
| 21 | +#' @template seealso_filter |
| 22 | +#' @export |
| 23 | +#' @examples |
| 24 | +#' if (requireNamespace("mlr3proba")) { |
| 25 | +#' task = tsk("rats") |
| 26 | +#' filter = flt("univariatecox") |
| 27 | +#' filter$calculate(task) |
| 28 | +#' as.data.table(filter) |
| 29 | +#' } |
| 30 | +#' |
| 31 | +#' if (mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3proba"), quietly = TRUE)) { |
| 32 | +#' library("mlr3pipelines") |
| 33 | +#' task = tsk("rats") |
| 34 | +#' |
| 35 | +#' # Note: `filter.cutoff` is selected randomly and should be tuned. |
| 36 | +#' # The significance level of `0.05` serves as a conventional threshold. |
| 37 | +#' # The filter returns the `-log`-transformed scores so we transform |
| 38 | +#' # the cutoff as well: |
| 39 | +#' cutoff = -log(0.05) # ~2.99 |
| 40 | +#' |
| 41 | +#' graph = |
| 42 | +#' po("filter", filter = flt("univariatecox"), filter.cutoff = cutoff) %>>% |
| 43 | +#' po("learner", lrn("surv.coxph")) |
| 44 | +#' learner = as_learner(graph) |
| 45 | +#' |
| 46 | +#' learner$train(task) |
| 47 | +#' |
| 48 | +#' # univariate cox filter scores |
| 49 | +#' learner$model$surv.univariatecox$scores |
| 50 | +#' |
| 51 | +#' # only two features had a score larger than the specified `cutoff` and |
| 52 | +#' # were used to train the CoxPH model |
| 53 | +#' learner$model$surv.coxph$train_task$feature_names |
| 54 | +#' } |
| 55 | +FilterUnivariateCox = R6Class("FilterUnivariateCox", |
| 56 | + inherit = Filter, |
| 57 | + public = list( |
| 58 | + #' @description Create a FilterUnivariateCox object. |
| 59 | + initialize = function() { |
| 60 | + super$initialize( |
| 61 | + id = "surv.univariatecox", |
| 62 | + packages = c("mlr3proba"), |
| 63 | + param_set = ps(), |
| 64 | + feature_types = c("integer", "numeric", "factor"), |
| 65 | + task_types = "surv", |
| 66 | + label = "Univariate Cox Survival Score", |
| 67 | + man = "mlr3filters::mlr_filters_univariatecox" |
| 68 | + ) |
| 69 | + } |
| 70 | + ), |
| 71 | + |
| 72 | + private = list( |
| 73 | + .calculate = function(task, nfeat) { |
| 74 | + t = task$clone() |
| 75 | + features = t$feature_names |
| 76 | + learner = lrn("surv.coxph") |
| 77 | + |
| 78 | + scores = map_dbl(features, function(feature) { |
| 79 | + t$col_roles$feature = feature |
| 80 | + learner$train(t) |
| 81 | + pval = summary(learner$model)$coefficients[, "Pr(>|z|)"] |
| 82 | + -log(pval) # smaller p-values => larger scores |
| 83 | + }) |
| 84 | + |
| 85 | + set_names(scores, features) |
| 86 | + } |
| 87 | + ) |
| 88 | +) |
| 89 | + |
| 90 | +#' @include mlr_filters.R |
| 91 | +mlr_filters$add("univariatecox", FilterUnivariateCox) |
0 commit comments