Skip to content

catboost and predict threshold #58

@pecto2020

Description

@pecto2020

Hi,
With the catboost engine, predict() in tidymodels does not seem to use the default class threshold of 0.5 but some other value. Does catboost apply a class_weight during training? If so, how do I change it in tidymodels/treesnip? I attach a comparison between catboost and random forest below.
Thanks

# Packages ----
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(mlbench)
library(catboost)
library(treesnip)

# Declare catboost and treesnip as package dependencies of the
# catboost engine for boost_tree().
set_dependency("boost_tree", eng = "catboost", pkg = "catboost")
set_dependency("boost_tree", eng = "catboost", pkg = "treesnip")

# Data ----
data(PimaIndiansDiabetes)
diabetes_orig <- PimaIndiansDiabetes

# Seed the RNG so the 75/25 split below is reproducible.
set.seed(123)
diabetes_split <- initial_split(diabetes_orig, prop = 0.75)
diabetes_split
#> <Analysis/Assess/Total>
#> <576/192/768>

# Training and test partitions used by the model sections that follow.
diabetes_train <- training(diabetes_split)
diabetes_test <- testing(diabetes_split)

# Train random forest ----

# Model specification: random forest classifier on the ranger engine.
trees_spec <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger")

# Fit on the training data.
trees_fit <- trees_spec %>%
  fit(diabetes ~ ., data = diabetes_train)

# Predict on the test set: hard class, class probabilities and the observed
# outcome, bound side by side.
trees_pred <- predict(trees_fit, diabetes_test) %>%
  bind_cols(predict(trees_fit, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Metrics for the engine's own class predictions.
# Both metrics must treat "pos" (the second factor level) as the event.
# The yardstick argument is `event_level` (singular); a misspelled
# `event_levels` is swallowed by `...`, leaving sens() on its default level.
trees_perf <- trees_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_pred %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# Re-derive the class from the "pos" probability with an explicit 0.5 cut-off.
# The factor is built with the outcome's own levels so that both classes are
# kept even if every prediction falls on one side of the threshold.
trees_05 <- trees_pred %>%
  mutate(
    .pred_class = factor(
      ifelse(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Same metrics for the thresholded predictions.
trees_perf_05 <- trees_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# NOTE(review): the printed values below come from the originally posted run.
# The roc_auc rows are unaffected by the changes above; the sens rows were
# produced with the misspelled `event_levels` argument and should be
# regenerated by re-running this script.
trees_perf
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.823
#> 2 sens    binary         0.856
trees_perf_05
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.823
#> 2 sens    binary         0.856

# Train catboost ----

# Model specification: boosted trees (depth 10) on the catboost engine,
# with `nthread = 4` passed through as an engine argument.
catboost_spec <- boost_tree(tree_depth = 10) %>%
  set_mode("classification") %>%
  set_engine("catboost", nthread = 4)

# Fit on the training data.
catboost_fit <- catboost_spec %>%
  fit(diabetes ~ ., data = diabetes_train)

# Predict on the test set: hard class, class probabilities and the observed
# outcome, bound side by side.
catboost_pred <- predict(catboost_fit, diabetes_test) %>%
  bind_cols(predict(catboost_fit, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Metrics for the engine's own class predictions.
# Both metrics must treat "pos" (the second factor level) as the event.
# The yardstick argument is `event_level` (singular); a misspelled
# `event_levels` is swallowed by `...`, leaving sens() on its default level.
catboost_perf <- catboost_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_pred %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# Re-derive the class from the "pos" probability with an explicit 0.5 cut-off.
# The factor is built with the outcome's own levels so that both classes are
# kept even if every prediction falls on one side of the threshold.
catboost_05 <- catboost_pred %>%
  mutate(
    .pred_class = factor(
      ifelse(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Same metrics for the thresholded predictions.
catboost_perf_05 <- catboost_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# NOTE(review): the printed values below come from the originally posted run.
# The roc_auc rows are unaffected by the changes above; the sens rows were
# produced with the misspelled `event_levels` argument and should be
# regenerated by re-running this script.
catboost_perf
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.801
#> 2 sens    binary         1
catboost_perf_05
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.801
#> 2 sens    binary         0.992


Created on 2022-02-02 by the reprex package (v2.0.1)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions