Skip to content

Commit cfd893f

Browse files
committed
Initial work to add permshap()
1 parent 09c4226 commit cfd893f

File tree

9 files changed

+508
-87
lines changed

9 files changed

+508
-87
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
S3method(kernelshap,Learner)
44
S3method(kernelshap,default)
55
S3method(kernelshap,ranger)
6+
S3method(permshap,Learner)
7+
S3method(permshap,default)
8+
S3method(permshap,ranger)
69
S3method(print,kernelshap)
710
S3method(summary,kernelshap)
811
export(is.kernelshap)
912
export(kernelshap)
13+
export(permshap)
1014
importFrom(foreach,"%dopar%")

R/kernelshap.R

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
#' In cases with a natural "off" value (like MNIST digits),
6969
#' this can also be a single row with all values set to the off value.
7070
#' @param pred_fun Prediction function of the form `function(object, X, ...)`,
71-
#' providing \eqn{K \ge 1} numeric predictions per row. Its first argument
71+
#' providing \eqn{K \ge 1} predictions per row. Its first argument
7272
#' represents the model `object`, its second argument a data structure like `X`.
7373
#' Additional (named) arguments are passed via `...`.
7474
#' The default, [stats::predict()], will work in most cases.
@@ -113,7 +113,7 @@
113113
#' @param max_iter If the stopping criterion (see `tol`) is not reached after
114114
#' `max_iter` iterations, the algorithm stops. Ignored if `exact = TRUE`.
115115
#' @param parallel If `TRUE`, use parallel [foreach::foreach()] to loop over rows
116-
#' to be explained. Must register backend beforehand, e.g., via {doFuture} package,
116+
#' to be explained. Must register backend beforehand, e.g., via 'doFuture' package,
117117
#' see README for an example. Parallelization automatically disables the progress bar.
118118
#' @param parallel_args Named list of arguments passed to [foreach::foreach()].
119119
#' Ideally, this is `NULL` (default). Only relevant if `parallel = TRUE`.
@@ -191,31 +191,19 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
191191
m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)),
192192
tol = 0.005, max_iter = 100L, parallel = FALSE,
193193
parallel_args = NULL, verbose = TRUE, ...) {
194+
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
195+
p <- length(feature_names)
194196
stopifnot(
195-
is.matrix(X) || is.data.frame(X),
196-
is.matrix(bg_X) || is.data.frame(bg_X),
197-
is.matrix(X) == is.matrix(bg_X),
198-
dim(X) >= 1L,
199-
dim(bg_X) >= 1L,
200-
!is.null(colnames(X)),
201-
!is.null(colnames(bg_X)),
202-
(p <- length(feature_names)) >= 1L,
203-
all(feature_names %in% colnames(X)),
204-
all(feature_names %in% colnames(bg_X)), # not necessary, but clearer
205-
all(colnames(X) %in% colnames(bg_X)),
206-
is.function(pred_fun),
207197
exact %in% c(TRUE, FALSE),
208198
p == 1L || exact || hybrid_degree %in% 0:(p / 2),
209199
paired_sampling %in% c(TRUE, FALSE),
210200
"m must be even" = trunc(m / 2) == m / 2
211201
)
202+
p <- length(feature_names)
212203
n <- nrow(X)
213204
bg_n <- nrow(bg_X)
214205
if (!is.null(bg_w)) {
215-
stopifnot(length(bg_w) == bg_n, all(bg_w >= 0), !all(bg_w == 0))
216-
if (!is.double(bg_w)) {
217-
bg_w <- as.double(bg_w)
218-
}
206+
bg_w <- prep_w(bg_w, bg_n = bg_n)
219207
}
220208
if (is.matrix(X) && !identical(colnames(X), feature_names)) {
221209
stop("If X is a matrix, feature_names must equal colnames(X)")
@@ -224,15 +212,14 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
224212
# Calculate v1 and v0
225213
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
226214
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
227-
v0 <- weighted_colMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
215+
v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
228216

229217
# For p = 1, exact Shapley values are returned
230218
if (p == 1L) {
231-
return(
232-
case_p1(
233-
n = n, feature_names = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose
234-
)
219+
out <- case_p1(
220+
n = n, feature_names = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose
235221
)
222+
return(out)
236223
}
237224

238225
# Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant

R/permshap.R

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
#' Permutation SHAP
2+
#'
3+
#' Exact permutation SHAP values with respect to a background dataset.
4+
#'
5+
#' @inheritParams kernelshap
6+
#' @returns
7+
#' An object of class "permshap" with the following components:
8+
#' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has
9+
#' dimension \eqn{K > 1}, a list of \eqn{K} such matrices.
10+
#' - `X`: Same as input argument `X`.
11+
#' - `baseline`: Vector of length K representing the average prediction on the
12+
#' background data.
13+
#' @export
14+
#' @examples
15+
#' # MODEL ONE: Linear regression
16+
#' fit <- lm(Sepal.Length ~ ., data = iris)
17+
#'
18+
#' # Select rows to explain (only feature columns)
19+
#' X_explain <- iris[1:2, -1]
20+
#'
21+
#' # Select small background dataset (could use all rows here because iris is small)
22+
#' set.seed(1)
23+
#' bg_X <- iris[sample(nrow(iris), 100), ]
24+
#'
25+
#' # Calculate SHAP values
26+
#' s <- permshap(fit, X_explain, bg_X = bg_X)
27+
#' s
28+
#'
29+
#' # MODEL TWO: Multi-response linear regression
30+
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
31+
#' s <- permshap(fit, iris[1:4, 3:5], bg_X = bg_X)
32+
#' s
33+
#'
34+
#' # Non-feature columns can be dropped via 'feature_names'
35+
#' s <- permshap(
36+
#' fit,
37+
#' iris[1:4, ],
38+
#' bg_X = bg_X,
39+
#' feature_names = c("Petal.Length", "Petal.Width", "Species")
40+
#' )
41+
#' s
42+
permshap <- function(object, ...) {
43+
UseMethod("permshap")
44+
}
45+
46+
#' @describeIn permshap Default permutation SHAP method.
47+
#' @export
48+
permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
49+
feature_names = colnames(X), bg_w = NULL,
50+
parallel = FALSE, parallel_args = NULL,
51+
verbose = TRUE, ...) {
52+
basic_checks(X = X, bg_X = bg_X, feature_names = feature_names, pred_fun = pred_fun)
53+
p <- length(feature_names)
54+
stopifnot(p <= 14L)
55+
n <- nrow(X)
56+
bg_n <- nrow(bg_X)
57+
if (!is.null(bg_w)) {
58+
bg_w <- prep_w(bg_w, bg_n = bg_n)
59+
}
60+
if (is.matrix(X) && !identical(colnames(X), feature_names)) {
61+
stop("If X is a matrix, feature_names must equal colnames(X)")
62+
}
63+
64+
if (verbose) {
65+
message("Exact permutation SHAP values")
66+
}
67+
68+
# Baseline
69+
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
70+
v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K
71+
72+
# Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
73+
# Predictions will never be applied directly to bg_X anymore
74+
if (!identical(colnames(bg_X), feature_names)) {
75+
bg_X <- bg_X[, feature_names, drop = FALSE]
76+
}
77+
78+
# Precalculations that are identical for each row to be explained
79+
Z <- exact_Z(p, feature_names = feature_names, keep_extremes = TRUE)
80+
m_exact <- nrow(Z)
81+
precalc <- list(
82+
Z = Z,
83+
Z_code = rowpaste(Z),
84+
bg_X_rep = bg_X[rep(seq_len(bg_n), times = m_exact), , drop = FALSE]
85+
)
86+
87+
if (m_exact * bg_n > 2e5) {
88+
warning("\nPredictions on large data sets with ", m_exact, "x", bg_n,
89+
" observations are being done.\n",
90+
"Consider reducing the computational burden (e.g. use smaller X_bg)")
91+
}
92+
93+
# Apply permutation SHAP to each row of X
94+
if (isTRUE(parallel)) {
95+
parallel_args <- c(list(i = seq_len(n)), parallel_args)
96+
res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one(
97+
x = X[i, , drop = FALSE],
98+
object = object,
99+
pred_fun = pred_fun,
100+
bg_w = bg_w,
101+
precalc = precalc,
102+
...
103+
)
104+
} else {
105+
if (verbose && n >= 2L) {
106+
pb <- utils::txtProgressBar(max = n, style = 3)
107+
}
108+
res <- vector("list", n)
109+
for (i in seq_len(n)) {
110+
res[[i]] <- permshap_one(
111+
x = X[i, , drop = FALSE],
112+
object = object,
113+
pred_fun = pred_fun,
114+
bg_w = bg_w,
115+
precalc = precalc,
116+
...
117+
)
118+
if (verbose && n >= 2L) {
119+
utils::setTxtProgressBar(pb, i)
120+
}
121+
}
122+
}
123+
out <- list(S = reorganize_list(res), X = X, baseline = as.vector(v0))
124+
class(out) <- "permshap"
125+
out
126+
}
127+
128+
#' @describeIn permshap Permutation SHAP method for "ranger" models, see Readme for an example.
129+
#' @export
130+
permshap.ranger <- function(object, X, bg_X,
131+
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
132+
feature_names = colnames(X),
133+
bg_w = NULL, parallel = FALSE, parallel_args = NULL,
134+
verbose = TRUE, ...) {
135+
permshap.default(
136+
object = object,
137+
X = X,
138+
bg_X = bg_X,
139+
pred_fun = pred_fun,
140+
feature_names = feature_names,
141+
bg_w = bg_w,
142+
parallel = parallel,
143+
parallel_args = parallel_args,
144+
verbose = verbose,
145+
...
146+
)
147+
}
148+
149+
#' @describeIn permshap Permutation SHAP method for "mlr3" models, see Readme for an example.
150+
#' @export
151+
permshap.Learner <- function(object, X, bg_X,
152+
pred_fun = NULL,
153+
feature_names = colnames(X),
154+
bg_w = NULL, parallel = FALSE, parallel_args = NULL,
155+
verbose = TRUE, ...) {
156+
if (is.null(pred_fun)) {
157+
pred_fun <- mlr3_pred_fun(object, X = X)
158+
}
159+
permshap.default(
160+
object = object,
161+
X = X,
162+
bg_X = bg_X,
163+
pred_fun = pred_fun,
164+
feature_names = feature_names,
165+
bg_w = bg_w,
166+
parallel = parallel,
167+
parallel_args = parallel_args,
168+
verbose = verbose,
169+
...
170+
)
171+
}

R/utils.R

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#' @param x A matrix-like object.
99
#' @param w Optional case weights.
1010
#' @returns A (1 x ncol(x)) matrix of column means.
11-
weighted_colMeans <- function(x, w = NULL, ...) {
11+
wcolMeans <- function(x, w = NULL, ...) {
1212
if (NCOL(x) == 1L && is.null(w)) {
1313
return(as.matrix(mean(x)))
1414
}
@@ -18,6 +18,66 @@ weighted_colMeans <- function(x, w = NULL, ...) {
1818
rbind(if (is.null(w)) colMeans(x) else colSums(x * w) / sum(w))
1919
}
2020

21+
#' All on-off Vectors
22+
#'
23+
#' Internal function that creates matrix of all on-off vectors of length `p`.
24+
#'
25+
#' @noRd
26+
#' @keywords internal
27+
#'
28+
#' @param p Number of features.
29+
#' @param feature_names Feature names.
30+
#' @param keep_extremes Should extremes be kept? Defaults to `FALSE` (for kernelshap).
31+
#' @returns An integer matrix of all on-off vectors of length `p`.
32+
exact_Z <- function(p, feature_names, keep_extremes = FALSE) {
33+
Z <- as.matrix(do.call(expand.grid, replicate(p, 0:1, simplify = FALSE)))
34+
colnames(Z) <- feature_names
35+
if (keep_extremes) Z else Z[2:(nrow(Z) - 1L), , drop = FALSE]
36+
}
37+
38+
#' Masker
39+
#'
40+
#' Internal function.
41+
#' For each on-off vector (rows in `Z`), the (weighted) average prediction is returned.
42+
#'
43+
#' @noRd
44+
#' @keywords internal
45+
#'
46+
#' @inheritParams kernelshap
47+
#' @param X Row to be explained stacked m*n_bg times.
48+
#' @param bg Background data stacked m times.
49+
#' @param Z A (m x p) matrix with on-off values.
50+
#' @param w A vector with case weights (of the same length as the unstacked
51+
#' background data).
52+
#' @returns A (m x K) matrix with vz values.
53+
get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
54+
m <- nrow(Z)
55+
not_Z <- !Z
56+
n_bg <- nrow(bg) / m # because bg was replicated m times
57+
58+
# Replicate not_Z, so that X, bg, not_Z are all of dimension (m*n_bg x p)
59+
g <- rep_each(m, each = n_bg)
60+
not_Z <- not_Z[g, , drop = FALSE]
61+
62+
if (is.matrix(X)) {
63+
# Remember that columns of X and bg are perfectly aligned in this case
64+
X[not_Z] <- bg[not_Z]
65+
} else {
66+
for (v in colnames(Z)) {
67+
s <- not_Z[, v]
68+
X[[v]][s] <- bg[[v]][s]
69+
}
70+
}
71+
preds <- align_pred(pred_fun(object, X, ...))
72+
73+
# Aggregate
74+
if (is.null(w)) {
75+
return(rowsum(preds, group = g, reorder = FALSE) / n_bg)
76+
}
77+
# w is recycled over rows and columns
78+
rowsum(preds * w, group = g, reorder = FALSE) / sum(w)
79+
}
80+
2181
#' Combine Matrices
2282
#'
2383
#' Binds list of matrices along new first axis.
@@ -183,6 +243,46 @@ fdummy <- function(x) {
183243
out
184244
}
185245

246+
#' Basic Input Checks
247+
#'
248+
#' @noRd
249+
#' @keywords internal
250+
#'
251+
#' @inheritParams kernelshap
252+
#'
253+
#' @returns TRUE or an error
254+
basic_checks <- function(X, bg_X, feature_names, pred_fun) {
255+
stopifnot(
256+
is.matrix(X) || is.data.frame(X),
257+
is.matrix(bg_X) || is.data.frame(bg_X),
258+
is.matrix(X) == is.matrix(bg_X),
259+
dim(X) >= 1L,
260+
dim(bg_X) >= 1L,
261+
!is.null(colnames(X)),
262+
!is.null(colnames(bg_X)),
263+
length(feature_names) >= 1L,
264+
all(feature_names %in% colnames(X)),
265+
all(feature_names %in% colnames(bg_X)), # not necessary, but clearer
266+
all(colnames(X) %in% colnames(bg_X)),
267+
is.function(pred_fun)
268+
)
269+
TRUE
270+
}
271+
272+
#' Prepare Case Weights
273+
#'
274+
#' @noRd
275+
#' @keywords internal
276+
#'
277+
#' @param w Vector of case weights.
278+
#' @param bg_n Number of rows in the background data.
279+
#'
280+
#' @returns TRUE or an error
281+
prep_w <- function(w, bg_n) {
282+
stopifnot(length(w) == bg_n, all(w >= 0), !all(w == 0))
283+
if (!is.double(w)) as.double(w) else w
284+
}
285+
186286
#' mlr3 Helper
187287
#'
188288
#' Returns the prediction function of a mlr3 Learner.

0 commit comments

Comments
 (0)