Variable importance plot lists only the original variables, not the one-hot-encoded variables, when permutation importance is calculated with model_parts() and explain_tidymodels()? #85
I am using explain_tidymodels() to compute variable importance. My workflow includes a recipe with a step_dummy() step. I'm trying to understand why, when this step is included, the variable importance calculated with model_parts() is reported for the original variables rather than for the one-hot-encoded variables. Is the permutation importance aggregated at some point for each group of one-hot-encoded variables that belong together? I didn't see this explained in the documentation. Reprex below. Please advise. Thank you.
library("DALEXtra")
library("tidymodels")
library("recipes")
# example with no dummy variables
data <- titanic_imputed
data$survived <- as.factor(data$survived)
rec <- recipe(survived ~ ., data = data) %>%
  step_normalize(fare)
model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")
wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)
model_fitted <- wflow %>%
  fit(data = data)
# as.numeric() on the factor alone would return the level codes 1/2, not 0/1
explainTest <- explain_tidymodels(model_fitted, data = data, y = as.numeric(as.character(data$survived)))
explainModelParts <- model_parts(explainTest, type="variable_importance")
plot(explainModelParts)
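In both examples the recipe lives inside the workflow, and the explainer is given the raw data. Assuming model_parts() permutes the columns of the explainer's data and then calls the workflow's predict method (which re-applies the recipe), importance can only be reported per raw column. The base-R sketch below mimics that setup on simulated data. one_hot(), predict_fun(), perm_importance() and one_minus_auc() are hypothetical helpers written for illustration, not DALEX functions, and rpart stands in for the tidymodels decision tree.

```r
# Sketch: permutation importance when one-hot encoding happens inside the
# prediction step. Simulated data; all helpers here are hypothetical.
library(rpart)

set.seed(1)
n <- 500
raw <- data.frame(
  gender = factor(sample(c("female", "male"), n, replace = TRUE)),
  class  = factor(sample(c("1st", "2nd", "3rd"), n, replace = TRUE)),
  fare   = rexp(n)
)
# Simulated outcome: strong gender effect, weak class effect, no fare effect
y <- rbinom(n, 1, plogis(ifelse(raw$gender == "female", 1.5, -1.5) +
                           ifelse(raw$class == "1st", 0.5, 0)))

# One indicator column per factor level (similar to step_dummy(one_hot = TRUE))
one_hot <- function(df) {
  out <- data.frame(fare = df$fare)
  for (v in c("gender", "class")) {
    for (lvl in levels(df[[v]])) {
      out[[paste(v, lvl, sep = "_")]] <- as.numeric(df[[v]] == lvl)
    }
  }
  out
}

# The model itself only ever sees the encoded columns
fit <- rpart(y ~ ., data = cbind(one_hot(raw), y = factor(y)), method = "class")

# Prediction takes RAW data and encodes it first, like predict() on a workflow
predict_fun <- function(df) {
  predict(fit, newdata = one_hot(df), type = "prob")[, "1"]
}

# Loss: 1 - AUC (Mann-Whitney form, average ranks for ties)
one_minus_auc <- function(obs, pred) {
  r <- rank(pred)
  n1 <- sum(obs == 1)
  n0 <- sum(obs == 0)
  1 - (sum(r[obs == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Permute one RAW column at a time; all of its dummy columns change together
perm_importance <- function(df, obs, B = 10) {
  full <- one_minus_auc(obs, predict_fun(df))
  sapply(names(df), function(v) {
    losses <- replicate(B, {
      shuffled <- df
      shuffled[[v]] <- sample(shuffled[[v]])
      one_minus_auc(obs, predict_fun(shuffled))
    })
    mean(losses) - full  # increase in loss over the unpermuted data
  })
}

imp <- perm_importance(raw, y)
print(round(sort(imp, decreasing = TRUE), 3))
```

The result has one entry per raw column (gender, class, fare) and never gender_female or gender_male, because the dummies are not permuted individually; nothing is aggregated afterwards. Permuting dummy columns one by one would also create rows the encoder can never produce, such as gender_female = 1 together with gender_male = 1. The sketch reports the increase over the unpermuted loss, which is what DALEX's type = "difference" reports; type = "variable_importance" keeps the raw permuted loss instead.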
# example with dummy variables
data <- titanic_imputed
data$survived <- as.factor(data$survived)
rec <- recipe(survived ~ ., data = data) %>%
  step_dummy(gender, class, embarked, one_hot = TRUE) %>% # one-hot encode the categorical variables
  step_normalize(fare)
model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")
wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)
model_fitted <- wflow %>%
  fit(data = data)
# as.numeric() on the factor alone would return the level codes 1/2, not 0/1
explainModel <- explain_tidymodels(model_fitted, data = data, y = as.numeric(as.character(data$survived)))
vipData <- model_parts(explainModel, type = "variable_importance")
plot(vipData) # this plot shows the original variable names, not the one-hot-encoded columns
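If importance for the individual one-hot columns is what is wanted, one option is to apply the trained recipe once and explain the underlying parsnip fit on the baked data, so that model_parts() permutes the dummy columns themselves. This is a sketch under assumptions: it assumes a DALEXtra version whose explain_tidymodels() accepts a parsnip model_fit as well as a workflow, a workflows release that provides extract_recipe() and extract_fit_parsnip(), and that model_parts() passes variable_groups through to DALEX's feature_importance(). The dummy names (gender_female, gender_male, ...) follow recipes' default naming.

```r
library("DALEXtra")
library("tidymodels")

data <- titanic_imputed
data$survived <- as.factor(data$survived)

rec <- recipe(survived ~ ., data = data) %>%
  step_dummy(gender, class, embarked, one_hot = TRUE) %>%
  step_normalize(fare)

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

model_fitted <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model) %>%
  fit(data = data)

# Apply the trained recipe once, so the explainer's data already holds the
# one-hot columns and the normalized fare
baked <- bake(extract_recipe(model_fitted), new_data = data) %>%
  select(-survived) %>%
  as.data.frame()

# Explain the fitted parsnip model (not the whole workflow), so permutation
# happens on the baked columns one at a time
explain_dummies <- explain_tidymodels(
  extract_fit_parsnip(model_fitted),
  data = baked,
  y = titanic_imputed$survived,  # numeric 0/1 target
  label = "rpart (baked data)"
)

vip_dummies <- model_parts(explain_dummies, type = "variable_importance")
plot(vip_dummies)

# Optional: permute each block of dummies jointly, which brings back one
# score per original categorical variable
groups <- list(
  gender   = grep("^gender_", names(baked), value = TRUE),
  class    = grep("^class_", names(baked), value = TRUE),
  embarked = grep("^embarked_", names(baked), value = TRUE)
)
vip_grouped <- model_parts(explain_dummies, variable_groups = groups)
plot(vip_grouped)
```

Permuting a whole block of dummies with the same row shuffle is equivalent to permuting the raw variable, so the grouped view should broadly agree with the workflow explainer (up to sampling noise), while the ungrouped view splits each variable across its levels.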