Skip to content

Commit c4cbb2e

Browse files
committed
add univariate cox filter + test
1 parent c629e6b commit c4cbb2e

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

R/FilterUnivariateCox.R

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#' @title Univariate Cox Survival Filter
2+
#'
3+
#' @name mlr_filters_univariatecox
4+
#'
5+
#' @description Calculates scores for assessing the relationship between
6+
#' individual features and the time-to-event outcome (right-censored survival
7+
#' data) using a univariate Cox proportional hazards model.
8+
#' The goal is to determine which features have a statistically significant
9+
#' association with the event of interest, typically in the context of clinical
10+
#' or biomedical research.
11+
#'
12+
#' This filter fits a [CoxPH][mlr3proba::LearnerSurvCoxPH()] learner using each
13+
#' feature independently and extracts the \eqn{p}-value that quantifies the
14+
#' significance of the feature's impact on survival. The filter value is
15+
#' `-log10(p)` where `p` is the \eqn{p}-value. This transformation is necessary
16+
#' to ensure numerical stability for very small \eqn{p}-values. Also higher
17+
#' values denote more important features.
18+
#'
19+
#' @family Filter
20+
#' @include Filter.R
21+
#' @template seealso_filter
22+
#' @export
23+
#' @examples
24+
#' if (requireNamespace("mlr3proba")) {
25+
#' task = tsk("rats")
26+
#' filter = flt("univariatecox")
27+
#' filter$calculate(task)
28+
#' as.data.table(filter)
29+
#' }
30+
#'
31+
#' if (mlr3misc::require_namespaces(c("mlr3pipelines", "mlr3proba"), quietly = TRUE)) {
32+
#' library("mlr3pipelines")
33+
#' task = tsk("rats")
34+
#'
35+
#' # Note: `filter.cutoff` is selected randomly and should be tuned.
36+
#' # The significance level of `0.05` serves as a conventional threshold.
37+
#' # The filter returns the `-log`-transformed scores so we transform
38+
#' # the cutoff as well:
39+
#' cutoff = -log(0.05) # ~2.99
40+
#'
41+
#' graph =
42+
#' po("filter", filter = flt("univariatecox"), filter.cutoff = cutoff) %>>%
43+
#' po("learner", lrn("surv.coxph"))
44+
#' learner = as_learner(graph)
45+
#'
46+
#' learner$train(task)
47+
#'
48+
#' # univariate cox filter scores
49+
#' learner$model$surv.univariatecox$scores
50+
#'
51+
#' # only two features had a score larger than the specified `cutoff` and
52+
#' # were used to train the CoxPH model
53+
#' learner$model$surv.coxph$train_task$feature_names
54+
#' }
55+
FilterUnivariateCox = R6Class("FilterUnivariateCox",
56+
inherit = Filter,
57+
public = list(
58+
#' @description Create a FilterUnivariateCox object.
59+
initialize = function() {
60+
super$initialize(
61+
id = "surv.univariatecox",
62+
packages = c("mlr3proba"),
63+
param_set = ps(),
64+
feature_types = c("integer", "numeric", "factor"),
65+
task_types = "surv",
66+
label = "Univariate Cox Survival Score",
67+
man = "mlr3filters::mlr_filters_univariatecox"
68+
)
69+
}
70+
),
71+
72+
private = list(
73+
.calculate = function(task, nfeat) {
74+
t = task$clone()
75+
features = t$feature_names
76+
learner = lrn("surv.coxph")
77+
78+
scores = map_dbl(features, function(feature) {
79+
t$col_roles$feature = feature
80+
learner$train(t)
81+
pval = summary(learner$model)$coefficients[, "Pr(>|z|)"]
82+
-log(pval) # smaller p-values => larger scores
83+
})
84+
85+
set_names(scores, features)
86+
}
87+
)
88+
)
89+
90+
#' @include mlr_filters.R
91+
mlr_filters$add("univariatecox", FilterUnivariateCox)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
skip_if_not_installed("mlr3proba")
2+
3+
test_that("FilterUnivariateCox", {
4+
t = tsk("rats")
5+
f = flt("univariatecox")
6+
f$calculate(t)
7+
8+
expect_filter(f, task = t)
9+
expect_true(all(f$scores >= 0))
10+
11+
# works with 2-level factors (but not 3-level ones)
12+
feature = "sex"
13+
expect_class(t$data(cols = feature)[[1]], "factor")
14+
15+
l = lrn("surv.coxph")
16+
t2 = t$clone()
17+
t2$col_roles$feature = feature
18+
l$train(t2)
19+
20+
expect_equal(-log(summary(l$model)$coefficients[,"Pr(>|z|)"]), f$scores[[feature]])
21+
})

0 commit comments

Comments
 (0)