
Are time predictions from decision_tree() with the "rpart" engine correct? #331

@jesusherranz


When I run decision_tree() with the "rpart" engine, both the concordance index reported by tune_grid() (via show_best()) and the concordance computed on the test set are below 0.5, whereas I expect them to be above 0.5. Below is an example using a well-known survival dataset from Hosmer (whas500), taken from the "smoothHR" package.
I have tested with other datasets and the results are similar. If I rerun the same script with the engine changed to "partykit", I get a concordance index above 0.7, which is what I would expect.
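For reference, a value below 0.5 means the predictions rank subjects in mostly the opposite order from their observed survival times. The following is a minimal base-R sketch of Harrell's C for predictions on the time scale (larger prediction = longer expected survival). `harrell_c_time()` is a hypothetical helper written for illustration, not a function from yardstick or censored, and it ignores tied event times for simplicity.

```r
# Hypothetical helper: Harrell's C for predictions on the *time* scale,
# i.e. a larger prediction should go with a longer survival time.
harrell_c_time <- function(time, status, pred) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    # Only subjects with an observed event can anchor a comparable pair
    if (status[i] != 1) next
    for (j in seq_len(n)) {
      # Pair (i, j) is comparable when i's event happens strictly before
      # j's event or censoring time
      if (time[i] < time[j]) {
        comparable <- comparable + 1
        if (pred[i] < pred[j]) {
          concordant <- concordant + 1    # shorter time, shorter prediction
        } else if (pred[i] == pred[j]) {
          concordant <- concordant + 0.5  # tied predictions count as half
        }
      }
    }
  }
  concordant / comparable
}

# Made-up toy data (status 1 = event, 0 = censored)
time   <- c(2, 5, 7, 10, 14)
status <- c(1, 0, 1, 1, 0)
pred   <- c(3, 6, 4, 12, 9)

harrell_c_time(time, status, pred)   # 6 of 7 comparable pairs concordant
harrell_c_time(time, status, -pred)  # reversing the order gives 1 - C
```

Flipping the sign of the predictions turns C into 1 - C, which is why a systematic value below 0.5 usually points to a ranking that runs in the wrong direction rather than to a model with no signal.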

library(tidymodels)
library(censored)
library(smoothHR)

## Data
whas500 <- whas500 %>% select(age, gender, hr, sysbp, diasbp, bmi, cvd, afb,
                              sho, chf, av3, miord, mitype, lenfol, fstat)

set.seed(252)
whas500_split <- initial_split(whas500, strata = fstat)
whas500_train <- training(whas500_split)
whas500_test <- testing(whas500_split)
                         
whas500_train <- whas500_train %>%
  mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")
whas500_test <- whas500_test %>%
  mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")

## resampling
set.seed(253)
cv_split <- vfold_cv(whas500_train, v = 10, repeats = 2 )

## Model specification
tree_spec <- 
    decision_tree( tree_depth = tune(), min_n = tune(),
                   cost_complexity = tune() ) %>%
    set_engine("rpart") %>% 
    set_mode("censored regression") 
              
## Workflow              
wflow_tree <- workflow() %>%
  add_model(tree_spec) %>% 
  add_formula(surv_var ~ . ) 
  
## Parameters Tune
tree_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(),
                          levels = 4 )    

tune_result_tree <- wflow_tree %>% 
  tune_grid( resamples = cv_split, grid = tree_grid, 
             metrics = metric_set(concordance_survival) ) 
show_best(tune_result_tree, metric="concordance_survival")

## Final workflow and final model
final_wflow_tree <- wflow_tree %>% 
  finalize_workflow( select_best(tune_result_tree, metric="concordance_survival") )
tree_fit <- final_wflow_tree %>% fit(whas500_train)
tree_fit

## Predictions in the testing sample
pred_tree_time <- predict(tree_fit, whas500_test, type = "time")
pred_tree_df <- bind_cols(whas500_test %>% select(surv_var), pred_tree_time ) 
head(pred_tree_df)

## Concordance
concordance_survival(pred_tree_df, truth = surv_var, estimate = .pred_time ) 
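One possible explanation, assumed here and not confirmed by this report, is that the "rpart" values returned as `.pred_time` behave like a risk score (larger value = shorter survival) instead of a time. The toy sketch below uses only the survival package and made-up numbers to show that the same ordering scored as a risk gives the mirrored concordance, and that `survival::concordance(..., reverse = TRUE)` undoes the mirroring. `pred_risk` is a hypothetical stand-in, not actual rpart output.

```r
library(survival)

# Made-up toy data: five subjects, all with observed events, distinct times
time   <- c(2, 5, 7, 10, 14)
status <- c(1, 1, 1, 1, 1)
y <- Surv(time, status)

# Predictions on the time scale: larger value = longer predicted survival
pred_time <- c(3, 4, 8, 9, 15)

# Hypothetical risk-scale version of the same ranking:
# larger value = shorter predicted survival
pred_risk <- 1 / pred_time

# With the default reverse = FALSE, larger predictor values are assumed
# to go with longer survival times
c_time <- concordance(y ~ pred_time)$concordance
c_risk <- concordance(y ~ pred_risk)$concordance

# reverse = TRUE treats larger predictor values as shorter survival
c_risk_rev <- concordance(y ~ pred_risk, reverse = TRUE)$concordance

c(c_time = c_time, c_risk = c_risk, c_risk_rev = c_risk_rev)
```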

Labels: bug (an unexpected problem or unintended behavior)