
Commit a690ff6 (merge of 2 parents: 09c4226 + bafd7e0)

Merge pull request #109 from ModelOriented/permshap

Add permshap()

24 files changed: +1343 -534 lines

.github/workflows/test-coverage.yaml

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ jobs:
       function_exclusions = c(
         "kernelshap\\.Learner",
         "kernelshap\\.ranger",
+        "permshap\\.Learner",
+        "permshap\\.ranger",
         "mlr3_pred_fun"
       )
     )

DESCRIPTION

Lines changed: 7 additions & 7 deletions
@@ -7,13 +7,13 @@ Authors@R: c(
     person("Przemyslaw", "Biecek", , "przemyslaw.biecek@gmail.com", role = "ctb",
            comment = c(ORCID = "0000-0001-8423-1823"))
   )
-Description: Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017),
-    and Covert and Lee (2021) <http://proceedings.mlr.press/v130/covert21a>.
-    For models with up to eight features, the results are exact regarding the
-    selected background data. Otherwise, an almost exact hybrid algorithm
-    involving iterative sampling is used. The package plays well together
-    with meta-learning packages like 'tidymodels', 'caret' or 'mlr3'.
-    Visualizations can be done using the R package 'shapviz'.
+Description: Efficient implementation of Kernel SHAP, see Lundberg and Lee
+    (2017), and Covert and Lee (2021)
+    <http://proceedings.mlr.press/v130/covert21a>. Furthermore, for up to
+    14 features, exact permutation SHAP values can be calculated. The
+    package plays well together with meta-learning packages like
+    'tidymodels', 'caret' or 'mlr3'. Visualizations can be done using the
+    R package 'shapviz'.
 License: GPL (>= 2)
 Depends:
     R (>= 3.2.0)
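
For context (not part of the diff): the basic workflow of the new function mirrors the examples added in R/permshap.R further down in this commit. A minimal sketch:

library(kernelshap)

# Linear regression on iris; explain two rows against a small background sample
fit <- lm(Sepal.Length ~ ., data = iris)
X_explain <- iris[1:2, -1]
set.seed(1)
bg_X <- iris[sample(nrow(iris), 100), ]

s <- permshap(fit, X_explain, bg_X = bg_X)  # exact permutation SHAP values
s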

NAMESPACE

Lines changed: 6 additions & 0 deletions
@@ -3,8 +3,14 @@
 S3method(kernelshap,Learner)
 S3method(kernelshap,default)
 S3method(kernelshap,ranger)
+S3method(permshap,Learner)
+S3method(permshap,default)
+S3method(permshap,ranger)
 S3method(print,kernelshap)
+S3method(print,permshap)
 S3method(summary,kernelshap)
 export(is.kernelshap)
+export(is.permshap)
 export(kernelshap)
+export(permshap)
 importFrom(foreach,"%dopar%")

NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 ## Major changes
 
+- Added `permshap()` to calculate exact permutation SHAP values. The function currently works for up to 14 features.
 - Factor-valued predictions are now supported. Each level is represented by its dummy variable.
 
 ## Other changes
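
Since exact Kernel SHAP (used for up to eight features) and the new exact permutation SHAP both compute exact Shapley values with respect to the same background data, one would expect their results to agree up to numerical error. An illustrative sanity check (not part of the commit):

fit <- lm(Sepal.Length ~ ., data = iris)
X_explain <- iris[1:2, -1]
bg_X <- iris[-1]  # all feature columns as background data

ks <- kernelshap(fit, X_explain, bg_X = bg_X)  # exact Kernel SHAP for p = 4 features
ps <- permshap(fit, X_explain, bg_X = bg_X)    # exact permutation SHAP

all.equal(ks$S, ps$S)  # expected TRUE (up to numerical tolerance)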

R/kernelshap.R

Lines changed: 10 additions & 29 deletions
@@ -68,7 +68,7 @@
 #' In cases with a natural "off" value (like MNIST digits),
 #' this can also be a single row with all values set to the off value.
 #' @param pred_fun Prediction function of the form `function(object, X, ...)`,
-#'   providing \eqn{K \ge 1} numeric predictions per row. Its first argument
+#'   providing \eqn{K \ge 1} predictions per row. Its first argument
 #'   represents the model `object`, its second argument a data structure like `X`.
 #'   Additional (named) arguments are passed via `...`.
 #'   The default, [stats::predict()], will work in most cases.
@@ -113,7 +113,7 @@
 #' @param max_iter If the stopping criterion (see `tol`) is not reached after
 #'   `max_iter` iterations, the algorithm stops. Ignored if `exact = TRUE`.
 #' @param parallel If `TRUE`, use parallel [foreach::foreach()] to loop over rows
-#'   to be explained. Must register backend beforehand, e.g., via {doFuture} package,
+#'   to be explained. Must register backend beforehand, e.g., via 'doFuture' package,
 #'   see README for an example. Parallelization automatically disables the progress bar.
 #' @param parallel_args Named list of arguments passed to [foreach::foreach()].
 #'   Ideally, this is `NULL` (default). Only relevant if `parallel = TRUE`.
@@ -191,19 +191,9 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
                                m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)),
                                tol = 0.005, max_iter = 100L, parallel = FALSE,
                                parallel_args = NULL, verbose = TRUE, ...) {
+  basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
+  p <- length(feature_names)
   stopifnot(
-    is.matrix(X) || is.data.frame(X),
-    is.matrix(bg_X) || is.data.frame(bg_X),
-    is.matrix(X) == is.matrix(bg_X),
-    dim(X) >= 1L,
-    dim(bg_X) >= 1L,
-    !is.null(colnames(X)),
-    !is.null(colnames(bg_X)),
-    (p <- length(feature_names)) >= 1L,
-    all(feature_names %in% colnames(X)),
-    all(feature_names %in% colnames(bg_X)),  # not necessary, but clearer
-    all(colnames(X) %in% colnames(bg_X)),
-    is.function(pred_fun),
     exact %in% c(TRUE, FALSE),
     p == 1L || exact || hybrid_degree %in% 0:(p / 2),
     paired_sampling %in% c(TRUE, FALSE),
@@ -212,27 +202,20 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
   n <- nrow(X)
   bg_n <- nrow(bg_X)
   if (!is.null(bg_w)) {
-    stopifnot(length(bg_w) == bg_n, all(bg_w >= 0), !all(bg_w == 0))
-    if (!is.double(bg_w)) {
-      bg_w <- as.double(bg_w)
-    }
-  }
-  if (is.matrix(X) && !identical(colnames(X), feature_names)) {
-    stop("If X is a matrix, feature_names must equal colnames(X)")
+    bg_w <- prep_w(bg_w, bg_n = bg_n)
   }
 
   # Calculate v1 and v0
   v1 <- align_pred(pred_fun(object, X, ...))  # Predictions on X: n x K
   bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
-  v0 <- weighted_colMeans(bg_preds, bg_w)  # Average pred of bg data: 1 x K
+  v0 <- wcolMeans(bg_preds, bg_w)  # Average pred of bg data: 1 x K
 
   # For p = 1, exact Shapley values are returned
   if (p == 1L) {
-    return(
-      case_p1(
-        n = n, feature_names = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose
-      )
+    out <- case_p1(
+      n = n, feature_names = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose
     )
+    return(out)
   }
 
   # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
@@ -266,9 +249,7 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
     message(txt)
   }
   if (max(m, m_exact) * bg_n > 2e5) {
-    warning("\nPredictions on large data sets with ", max(m, m_exact), "x", bg_n,
-            " observations are being done.\n",
-            "Consider reducing the computational burden (e.g. use smaller X_bg)")
+    warning_burden(max(m, m_exact), bg_n = bg_n)
  }
 
   # Apply Kernel SHAP to each row of X
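
The refactor above moves the inline input validation into small helpers (basic_checks(), prep_w()), renames weighted_colMeans() to wcolMeans(), and wraps the old warning text in warning_burden(); permshap() reuses the same helpers. Their definitions are not shown in this excerpt. The following sketch reconstructs two of them from the removed inline code; the actual implementations elsewhere in the commit may differ:

# Assumed helper sketches (hypothetical; the real code likely lives in R/utils.R)
basic_checks <- function(X, bg_X, feature_names, pred_fun) {
  stopifnot(
    is.matrix(X) || is.data.frame(X),
    is.matrix(bg_X) || is.data.frame(bg_X),
    is.matrix(X) == is.matrix(bg_X),
    dim(X) >= 1L,
    dim(bg_X) >= 1L,
    !is.null(colnames(X)),
    !is.null(colnames(bg_X)),
    length(feature_names) >= 1L,
    all(feature_names %in% colnames(X)),
    all(feature_names %in% colnames(bg_X)),
    all(colnames(X) %in% colnames(bg_X)),
    is.function(pred_fun)
  )
  if (is.matrix(X) && !identical(colnames(X), feature_names)) {
    stop("If X is a matrix, feature_names must equal colnames(X)")
  }
  invisible(TRUE)
}

prep_w <- function(w, bg_n) {
  # Validate and coerce background weights, as the removed inline code did
  stopifnot(length(w) == bg_n, all(w >= 0), !all(w == 0))
  if (!is.double(w)) as.double(w) else w
}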

R/methods.R

Lines changed: 24 additions & 1 deletion
@@ -18,6 +18,12 @@ print.kernelshap <- function(x, n = 2L, ...) {
   invisible(x)
 }
 
+#' @describeIn print.kernelshap Print method for "permshap" object
+#' @export
+print.permshap <- function(x, n = 2L, ...) {
+  print.kernelshap(x, n = n, ...)
+}
+
 #' Summary Method
 #'
 #' @param object An object of class "kernelshap".
@@ -76,11 +82,28 @@ summary.kernelshap <- function(object, compact = FALSE, n = 2L, ...) {
 #' @returns `TRUE` if `object` is of class "kernelshap", and `FALSE` otherwise.
 #' @export
 #' @examples
-#' fit <- stats::lm(Sepal.Length ~ ., data = iris)
+#' fit <- lm(Sepal.Length ~ ., data = iris)
 #' s <- kernelshap(fit, iris[1:2, -1], bg_X = iris[-1])
 #' is.kernelshap(s)
 #' is.kernelshap("a")
 #' @seealso [kernelshap()]
 is.kernelshap <- function(object){
   inherits(object, "kernelshap")
 }
+
+#' Check for permshap
+#'
+#' Is object of class "permshap"?
+#'
+#' @param object An R object.
+#' @returns `TRUE` if `object` is of class "permshap", and `FALSE` otherwise.
+#' @export
+#' @examples
+#' fit <- lm(Sepal.Length ~ ., data = iris)
+#' s <- permshap(fit, iris[1:2, -1], bg_X = iris[-1])
+#' is.permshap(s)
+#' is.permshap("a")
+#' @seealso [kernelshap()]
+is.permshap <- function(object){
+  inherits(object, "permshap")
+}
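
A quick illustration (not part of the commit) of the new methods: print() dispatches to print.permshap(), which simply reuses print.kernelshap(), and is.permshap() is the class check documented above.

fit <- lm(Sepal.Length ~ ., data = iris)
s <- permshap(fit, iris[1:3, -1], bg_X = iris[-1])
print(s, n = 3L)  # delegates to print.kernelshap()
is.permshap(s)    # TRUE
is.permshap("a")  # FALSE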

R/permshap.R

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+#' Permutation SHAP
+#'
+#' Exact permutation SHAP values with respect to a background dataset.
+#' The function is currently limited to maximum 14 features.
+#'
+#' @inheritParams kernelshap
+#' @returns
+#' An object of class "permshap" with the following components:
+#' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has
+#'   dimension \eqn{K > 1}, a list of \eqn{K} such matrices.
+#' - `X`: Same as input argument `X`.
+#' - `baseline`: Vector of length K representing the average prediction on the
+#'   background data.
+#' - `m_exact`: Integer providing the effective number of exact on-off vectors used.
+#' - `exact`: Logical flag indicating whether calculations are exact or not
+#'   (currently `TRUE`).
+#' - `txt`: Summary text.
+#' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
+#' @export
+#' @examples
+#' # MODEL ONE: Linear regression
+#' fit <- lm(Sepal.Length ~ ., data = iris)
+#'
+#' # Select rows to explain (only feature columns)
+#' X_explain <- iris[1:2, -1]
+#'
+#' # Select small background dataset (could use all rows here because iris is small)
+#' set.seed(1)
+#' bg_X <- iris[sample(nrow(iris), 100), ]
+#'
+#' # Calculate SHAP values
+#' s <- permshap(fit, X_explain, bg_X = bg_X)
+#' s
+#'
+#' # MODEL TWO: Multi-response linear regression
+#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
+#' s <- permshap(fit, iris[1:4, 3:5], bg_X = bg_X)
+#' s
+#'
+#' # Non-feature columns can be dropped via 'feature_names'
+#' s <- permshap(
+#'   fit,
+#'   iris[1:4, ],
+#'   bg_X = bg_X,
+#'   feature_names = c("Petal.Length", "Petal.Width", "Species")
+#' )
+#' s
+permshap <- function(object, ...) {
+  UseMethod("permshap")
+}
+
+#' @describeIn permshap Default permutation SHAP method.
+#' @export
+permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
+                             feature_names = colnames(X), bg_w = NULL,
+                             parallel = FALSE, parallel_args = NULL,
+                             verbose = TRUE, ...) {
+  basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
+  p <- length(feature_names)
+  stopifnot("Permutation SHAP only supported for up to 14 features" = p <= 14L)
+  n <- nrow(X)
+  bg_n <- nrow(bg_X)
+  if (!is.null(bg_w)) {
+    bg_w <- prep_w(bg_w, bg_n = bg_n)
+  }
+  txt <- "Exact permutation SHAP"
+  if (verbose) {
+    message(txt)
+  }
+
+  # Baseline and predictions on explanation data (latter not required in algo)
+  bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
+  v0 <- wcolMeans(bg_preds, bg_w)             # Average pred of bg data: 1 x K
+  v1 <- align_pred(pred_fun(object, X, ...))  # Predictions on X: n x K
+
+  # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
+  # Predictions will never be applied directly to bg_X anymore
+  if (!identical(colnames(bg_X), feature_names)) {
+    bg_X <- bg_X[, feature_names, drop = FALSE]
+  }
+
+  # Precalculations that are identical for each row to be explained
+  Z <- exact_Z(p, feature_names = feature_names, keep_extremes = TRUE)
+  m_exact <- nrow(Z)
+  precalc <- list(
+    Z = Z,
+    Z_code = rowpaste(Z),
+    bg_X_rep = bg_X[rep(seq_len(bg_n), times = m_exact), , drop = FALSE]
+  )
+
+  if (m_exact * bg_n > 2e5) {
+    warning_burden(m_exact, bg_n = bg_n)
+  }
+
+  # Apply permutation SHAP to each row of X
+  if (isTRUE(parallel)) {
+    parallel_args <- c(list(i = seq_len(n)), parallel_args)
+    res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one(
+      x = X[i, , drop = FALSE],
+      object = object,
+      pred_fun = pred_fun,
+      bg_w = bg_w,
+      precalc = precalc,
+      ...
+    )
+  } else {
+    if (verbose && n >= 2L) {
+      pb <- utils::txtProgressBar(max = n, style = 3)
+    }
+    res <- vector("list", n)
+    for (i in seq_len(n)) {
+      res[[i]] <- permshap_one(
+        x = X[i, , drop = FALSE],
+        object = object,
+        pred_fun = pred_fun,
+        bg_w = bg_w,
+        precalc = precalc,
+        ...
+      )
+      if (verbose && n >= 2L) {
+        utils::setTxtProgressBar(pb, i)
+      }
+    }
+  }
+  out <- list(
+    S = reorganize_list(res),
+    X = X,
+    baseline = as.vector(v0),
+    m_exact = m_exact,
+    exact = TRUE,
+    txt = txt,
+    predictions = v1
+  )
+  class(out) <- "permshap"
+  out
+}
+
+#' @describeIn permshap Permutation SHAP method for "ranger" models, see Readme for an example.
+#' @export
+permshap.ranger <- function(object, X, bg_X,
+                            pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
+                            feature_names = colnames(X),
+                            bg_w = NULL, parallel = FALSE, parallel_args = NULL,
+                            verbose = TRUE, ...) {
+  permshap.default(
+    object = object,
+    X = X,
+    bg_X = bg_X,
+    pred_fun = pred_fun,
+    feature_names = feature_names,
+    bg_w = bg_w,
+    parallel = parallel,
+    parallel_args = parallel_args,
+    verbose = verbose,
+    ...
+  )
+}
+
+#' @describeIn permshap Permutation SHAP method for "mlr3" models, see Readme for an example.
+#' @export
+permshap.Learner <- function(object, X, bg_X,
+                             pred_fun = NULL,
+                             feature_names = colnames(X),
+                             bg_w = NULL, parallel = FALSE, parallel_args = NULL,
+                             verbose = TRUE, ...) {
+  if (is.null(pred_fun)) {
+    pred_fun <- mlr3_pred_fun(object, X = X)
+  }
+  permshap.default(
+    object = object,
+    X = X,
+    bg_X = bg_X,
+    pred_fun = pred_fun,
+    feature_names = feature_names,
+    bg_w = bg_w,
+    parallel = parallel,
+    parallel_args = parallel_args,
+    verbose = verbose,
+    ...
+  )
+}
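
Why the 14-feature cap: the exact algorithm precalculates an on-off matrix Z via exact_Z() with keep_extremes = TRUE and evaluates the prediction function on the background data replicated once per row of Z. Assuming Z enumerates all subsets of the p features (including the empty and full subsets), m_exact grows as 2^p, which is why the burden warning fires quickly for larger p. A rough illustration under that assumption (not part of the commit):

# Hypothetical cost illustration (assumes m_exact = 2^p)
p <- c(8L, 10L, 12L, 14L)
bg_n <- 100L                 # background rows
m_exact <- 2^p               # on-off vectors evaluated per explained row
data.frame(p = p, m_exact = m_exact, preds_per_row = m_exact * bg_n)
# At p = 14 and bg_n = 100, each explained row needs 1,638,400 background
# predictions, well above the 2e5 threshold used by warning_burden().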
