Open
Labels
bug: an unexpected problem or unintended behavior
Description
When I fit decision_tree() with the "rpart" engine in "censored regression" mode, the concordance index is below 0.5, both in the tuning results from tune_grid() and when computed on a test sample. I expected values above 0.5. The example below uses a well-known survival data set, whas500 (Hosmer), taken from the "smoothHR" package.
I have tested with other data sets and the results are similar. If I rerun the same script with the engine changed to "partykit", I get a concordance index above 0.7, which is what I would expect.
library(tidymodels)
library(censored)
library(smoothHR)
## Data
whas500 <- whas500 %>%
  select(age, gender, hr, sysbp, diasbp, bmi, cvd, afb,
         sho, chf, av3, miord, mitype, lenfol, fstat)
set.seed(252)
whas500_split <- initial_split(whas500, strata = fstat)
whas500_train <- training(whas500_split)
whas500_test <- testing(whas500_split)
whas500_train <- whas500_train %>%
  mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")
whas500_test <- whas500_test %>%
  mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")
## Resampling
set.seed(253)
cv_split <- vfold_cv(whas500_train, v = 10, repeats = 2)
## Model specification
tree_spec <-
  decision_tree(tree_depth = tune(), min_n = tune(),
                cost_complexity = tune()) %>%
  set_engine("rpart") %>%
  set_mode("censored regression")
## Workflow
wflow_tree <- workflow() %>%
  add_model(tree_spec) %>%
  add_formula(surv_var ~ .)
## Parameter tuning
tree_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(),
                          levels = 4)
tune_result_tree <- wflow_tree %>%
  tune_grid(resamples = cv_split, grid = tree_grid,
            metrics = metric_set(concordance_survival))
show_best(tune_result_tree, metric = "concordance_survival")
## Final workflow and final model
final_wflow_tree <- wflow_tree %>%
  finalize_workflow(select_best(tune_result_tree, metric = "concordance_survival"))
tree_fit <- final_wflow_tree %>% fit(whas500_train)
tree_fit
## Predictions in the testing sample
pred_tree_time <- predict(tree_fit, whas500_test, type = "time")
pred_tree_df <- bind_cols(whas500_test %>% select(surv_var), pred_tree_time)
head(pred_tree_df)
## Concordance
concordance_survival(pred_tree_df, truth = surv_var, estimate = .pred_time)
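A concordance index well below 0.5 usually means the predictions are ranked in the opposite direction to the observed times rather than being uninformative. The sketch below is a possible diagnostic, not a confirmed explanation of this bug. It uses survival::concordance() on made-up toy data (the `toy` data frame and its values are hypothetical) to show that negating a predictor turns a concordance of C into 1 - C when the predictor has no ties.

```r
library(survival)

# Hypothetical toy data: `pred` mostly increases with the observed time.
toy <- data.frame(
  time   = c(5, 10, 15, 20, 25, 30),
  status = c(1, 1, 0, 1, 1, 0),      # 1 = event, 0 = censored
  pred   = c(6, 12, 9, 22, 24, 35)   # no tied predictions
)
toy$neg_pred <- -toy$pred            # same predictions, ranking reversed

# For a plain predictor, larger values paired with longer survival
# count as concordant.
c_fwd <- concordance(Surv(time, status) ~ pred, data = toy)$concordance
c_rev <- concordance(Surv(time, status) ~ neg_pred, data = toy)$concordance

c_fwd          # above 0.5: ranking agrees with the observed times
c_rev          # below 0.5: the same information, ranked backwards
c_fwd + c_rev  # sums to 1 because there are no tied predictions
```

Applying the same check to the script above (computing concordance_survival() on the negated `.pred_time` column) would show whether the "rpart" predictions carry ranking information in the reverse direction. If the flipped value is close to what "partykit" gives, that would suggest the predicted values are ordered like a risk score rather than a survival time, though this is an assumption to verify and not something established here.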