Variable importance plot lists only original variables, not one-hot-encoded variables, when permutation importance is calculated with model_parts() and explain_tidymodels() #85

@kransom14

I am using explain_tidymodels() to compute variable importance for a workflow whose recipe includes a step_dummy() step. When that step is included, the variable importance calculated with model_parts() is reported for the original variables rather than for the one-hot-encoded variables, and I'm trying to understand why. Is the permutation importance aggregated at some point across the group of one-hot-encoded variables that come from the same original variable? I didn't see this explained in the documentation. Reprex below. Please advise. Thank you.

library("DALEXtra")
library("tidymodels")
library("recipes")

# example with no dummy variables
data <- titanic_imputed

data$survived <- as.factor(data$survived)

rec <- recipe(survived ~ ., data = data) %>%
  step_normalize(fare)

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)

model_fitted <- wflow %>%
  fit(data = data)

# as.numeric() on a factor returns the level codes (1/2), so subtract 1 to get 0/1
explainTest <- explain_tidymodels(model_fitted, data = data, y = as.numeric(data$survived) - 1)
explainModelParts <- model_parts(explainTest, type = "variable_importance")
plot(explainModelParts)


# example with dummy variables
data <- titanic_imputed

data$survived <- as.factor(data$survived)

rec <- recipe(survived ~ ., data = data) %>%
  step_dummy(gender, class, embarked, one_hot = TRUE) %>% # one hot encode the categorical variables
  step_normalize(fare)

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)

model_fitted <- wflow %>%
  fit(data = data)

# as.numeric() on a factor returns the level codes (1/2), so subtract 1 to get 0/1
explainModel <- explain_tidymodels(model_fitted, data = data, y = as.numeric(data$survived) - 1)

vipData <- model_parts(explainModel, type = "variable_importance")
plot(vipData) # this plot shows original variable names and does not include the one hot encoded variables
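
One possible explanation, sketched in base R only. It rests on an assumption about the mechanism and is not DALEX's actual code: the columns of the data passed to the explainer are permuted first, and the encoding happens afterwards inside the prediction pipeline. If that is the case, nothing is aggregated; the dummy columns are simply never permuted individually. The data, the `encode()` helper, and the squared-error loss below are all hypothetical stand-ins for the titanic data, the recipe, and the loss used by model_parts().

```r
# Base-R sketch (NOT DALEX internals): permute RAW columns, then let the
# "pipeline" one-hot encode them before predicting.
set.seed(1)
n <- 500
raw <- data.frame(
  class = factor(sample(c("1st", "2nd", "3rd"), n, replace = TRUE)),
  fare  = runif(n, 0, 100)
)

# Synthetic 0/1 outcome that depends on both raw variables
p <- plogis(-1 + 1.5 * (raw$class == "1st") + 0.02 * raw$fare)
y <- rbinom(n, 1, p)

# Stand-in for the recipe: one-hot encode `class` (full indicator set, no intercept)
encode <- function(d) {
  as.data.frame(model.matrix(~ class + fare - 1, data = d))
}

# The model itself only ever sees the encoded columns
train <- cbind(encode(raw), y = y)
fit <- glm(y ~ . - 1, data = train, family = binomial())

# Stand-in for the workflow: raw data in, encoding applied, probabilities out
predict_pipeline <- function(d) {
  predict(fit, newdata = encode(d), type = "response")
}

# Illustrative loss (mean squared error on the predicted probability)
loss <- function(obs, pred) mean((obs - pred)^2)
baseline <- loss(y, predict_pipeline(raw))

# Permute one RAW column at a time and record the change in loss
importance <- sapply(names(raw), function(v) {
  shuffled <- raw
  shuffled[[v]] <- sample(shuffled[[v]])
  loss(y, predict_pipeline(shuffled)) - baseline
})

names(encode(raw))  # encoded columns seen by the model: class1st, class2nd, class3rd, fare
names(importance)   # one entry per raw column: "class" "fare"
```

In this sketch the importance vector is indexed by the raw column names because the shuffling loop runs over `names(raw)`, even though the fitted model was trained on the encoded columns.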
