Commit bafd7e0

NEWS and README

1 parent 8552018 commit bafd7e0

4 files changed: +421 -420 lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -2,6 +2,7 @@
 
 ## Major changes
 
+- Added `permshap()` to calculate exact permutation SHAP values. The function currently works for up to 14 features.
 - Factor-valued predictions are now supported. Each level is represented by its dummy variable.
 
 ## Other changes
````

README.md

Lines changed: 55 additions & 17 deletions
````diff
@@ -13,22 +13,30 @@
 
 ## Overview
 
-This package offers an efficient implementation of Kernel SHAP, see [1] and [2]. For up to $p=8$ features, the resulting Kernel SHAP values are exact regarding the selected background data. For larger $p$, an almost exact hybrid algorithm involving iterative sampling is used by default.
+The package contains two workhorses to calculate SHAP values for any model:
 
-The typical workflow to explain any model `object`:
+- `kernelshap()`: Kernel SHAP algorithm of [1] and [2]. By default, exact Kernel SHAP is used for up to $p=8$ features, and an almost exact hybrid algorithm otherwise.
+- `permshap()`: Exact permutation SHAP (currently available for up to $p=14$ features).
+
+### Kernel SHAP or permutation SHAP?
+
+- Exact Kernel SHAP and exact permutation SHAP values agree for additive models, and differ for models with interactions.
+- If the number of features is sufficiently small, we recommend `permshap()` over `kernelshap()`.
+
+### Typical workflow to explain any model
 
 1. **Sample rows to explain:** Sample 500 to 2000 rows `X` to be explained. If the training dataset is small, simply use the full training data for this purpose. `X` should only contain feature columns.
-2. **Select background data:** Kernel SHAP requires a representative background dataset `bg_X` to calculate marginal means. For this purpose, set aside 50 to 500 rows from the training data.
+2. **Select background data:** Both algorithms require a representative background dataset `bg_X` to calculate marginal means. For this purpose, set aside 50 to 500 rows from the training data.
    If the training data is small, use the full training data. In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value.
-3. **Crunch:** Use `kernelshap(object, X, bg_X, ...)` to calculate SHAP values. Runtime is proportional to `nrow(X)`, while memory consumption scales linearly in `nrow(bg_X)`.
-4. **Analyze:** Use {shapviz} to visualize the result.
+3. **Crunch:** Use `kernelshap(object, X, bg_X, ...)` or `permshap(object, X, bg_X, ...)` to calculate SHAP values. Runtime is proportional to `nrow(X)`, while memory consumption scales linearly in `nrow(bg_X)`.
+4. **Analyze:** Use {shapviz} to visualize the results.
 
 **Remarks**
 
 - Multivariate predictions are handled at no additional computational cost.
 - Factor-valued predictions are automatically turned into one-hot-encoded columns.
-- By changing the defaults, the iterative pure sampling approach in [2] can be enforced.
 - Case weights are supported via the argument `bg_w`.
+- By changing the defaults in `kernelshap()`, the iterative pure sampling approach in [2] can be enforced.
 
 ## Installation
 
````
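A minimal end-to-end sketch of these four steps, using a small linear model on `iris` purely for illustration (object names are ours; any model with a `predict()` method works the same way):

```r
library(kernelshap)
library(shapviz)

# Additive toy model
fit <- lm(Sepal.Length ~ ., data = iris)

# 1. Rows to explain: iris is small, so use all rows (feature columns only)
X <- iris[, -1]

# 2. Background data: again the full (small) training data
bg_X <- iris

# 3. Crunch: with four features, permshap() is exact; since the model is
#    additive, kernelshap() would return the same values here
ps <- permshap(fit, X, bg_X = bg_X)

# 4. Analyze
sv <- shapviz(ps)
sv_importance(sv)
```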
````diff
@@ -82,6 +90,17 @@ shap_lm
 sv_lm <- shapviz(shap_lm)
 sv_importance(sv_lm)
 sv_dependence(sv_lm, "log_carat", color_var = NULL)
+
+# Since the model is additive, permutation SHAP gives the same results:
+system.time(
+  permshap_lm <- permshap(fit_lm, X, bg_X = bg_X)
+)
+permshap_lm
+
+# SHAP values of first observations:
+#       log_carat    clarity       color         cut
+# [1,]  1.2692479  0.1081900 -0.07847065 0.004630899
+# [2,] -0.4499226 -0.1111329  0.11832292 0.026503850
 ```
 
 ![](man/figures/README-lm-imp.svg)
````
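The claimed agreement for this additive model can also be checked numerically; a quick sketch, assuming the objects `shap_lm` and `permshap_lm` from the snippet above and that both expose their SHAP matrix as `$S`:

```r
# TRUE up to numerical tolerance, because the model has no interactions
all.equal(shap_lm$S, permshap_lm$S)
```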
````diff
@@ -126,6 +145,17 @@ shap_rf
 sv_rf <- shapviz(shap_rf)
 sv_importance(sv_rf, kind = "bee", show_numbers = TRUE)
 sv_dependence(sv_rf, "log_carat")
+
+# Permutation SHAP gives very slightly different results here (due to interactions):
+system.time(
+  permshap_rf <- permshap(fit_rf, X, bg_X = bg_X)
+)
+permshap_rf
+#
+# SHAP values of first observations:
+#       log_carat     clarity      color         cut
+# [1,]  1.1986635  0.09557752 -0.1385312 0.001842753
+# [2,] -0.4970758 -0.12034448  0.1051721 0.030014490
 ```
 
 ![](man/figures/README-rf-imp.jpeg)
````
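The size of the interaction-driven disagreement can be inspected the same way, again assuming both objects expose their SHAP matrix as `$S`:

```r
# Largest absolute difference between Kernel SHAP and permutation SHAP values
max(abs(shap_rf$S - permshap_rf$S))
```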
````diff
@@ -205,7 +235,7 @@ library(mgcv)
 fit_gam <- gam(log_price ~ s(log_carat) + clarity + color + cut, data = diamonds)
 
 system.time( # 11 seconds
-  shap_gam <- kernelshap(
+  shap_gam <- permshap(
     fit_gam,
     X,
     bg_X = bg_X,
````
````diff
@@ -224,7 +254,8 @@ shap_gam
 
 ## Multi-output models
 
-{kernelshap} supports multivariate predictions, such as:
+{kernelshap} supports multivariate predictions:
+
 - probabilistic classification,
 - non-probabilistic classification (factor-valued responses are turned into dummies),
 - regression with multivariate response, and
````
````diff
@@ -241,17 +272,19 @@ library(shapviz)
 library(ggplot2)
 
 # Probabilistic
-fit_prob <- ranger(Species ~ ., data = iris, num.trees = 20, probability = TRUE, seed = 1)
-ks_prob <- kernelshap(fit_prob, X = iris, bg_X = iris) |>
+fit_prob <- ranger(
+  Species ~ ., data = iris, num.trees = 20, probability = TRUE, seed = 1
+)
+ps_prob <- permshap(fit_prob, X = iris[, -5], bg_X = iris) |>
   shapviz()
-sv_importance(ks_prob) +
+sv_importance(ps_prob) +
   ggtitle("Probabilistic")
 
 # Non-probabilistic: Predictions are factors
 fit_class <- ranger(Species ~ ., data = iris, num.trees = 20, seed = 1)
-ks_class <- kernelshap(fit_class, X = iris, bg_X = iris) |>
+ps_class <- permshap(fit_class, X = iris[, -5], bg_X = iris) |>
   shapviz()
-sv_importance(ks_class) +
+sv_importance(ps_class) +
   ggtitle("Non-Probabilistic")
 ```
 
````
````diff
@@ -282,8 +315,13 @@ iris_wf <- workflow() %>%
 fit <- iris_wf %>%
   fit(iris)
 
-ks <- kernelshap(fit, iris[, -1], bg_X = iris)
+ks <- permshap(fit, iris[, -1], bg_X = iris)
 ks
+#
+# SHAP values of first observations:
+#      Sepal.Width Petal.Length Petal.Width   Species
+# [1,]  0.21951350    -1.955357   0.3149451 0.5823533
+# [2,] -0.02843097    -1.955357   0.3149451 0.5823533
 ```
 
 ### caret
````
````diff
@@ -318,14 +356,14 @@ mlr_tasks$get("iris")
 task_iris <- TaskRegr$new(id = "reg", backend = iris, target = "Sepal.Length")
 fit_lm <- lrn("regr.lm")
 fit_lm$train(task_iris)
-s <- kernelshap(fit_lm, iris[-1], bg_X = iris)
+s <- permshap(fit_lm, iris[, -1], bg_X = iris)
 s
 
 # *Probabilistic* classification -> lrn(..., predict_type = "prob")
 task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
-fit_rf <- lrn("classif.ranger", predict_type = "prob", num.trees = 50)
+fit_rf <- lrn("classif.ranger", predict_type = "prob", num.trees = 20)
 fit_rf$train(task_iris)
-s <- kernelshap(fit_rf, X = iris[-5], bg_X = iris)
+s <- permshap(fit_rf, X = iris[, -5], bg_X = iris)
 s
 ```
````