Skip to content

Commit 62f4958

Browse files
authored
Merge pull request #175 from ModelOriented/rep-Z
Precalculate Z
2 parents d86ab1e + d02b623 commit 62f4958

File tree

11 files changed

+217
-104
lines changed

11 files changed

+217
-104
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: kernelshap
22
Title: Kernel SHAP
3-
Version: 0.9.0
3+
Version: 0.9.1
44
Authors@R: c(
55
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre"),
66
comment = c(ORCID = "0009-0007-2540-9629")),

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# kernelshap 0.9.1
2+
3+
### Speed and memory improvements
4+
5+
- More pre-calculations for the exact part of the methods ([#175](https://github.com/ModelOriented/kernelshap/pull/175)).
6+
17
# kernelshap 0.9.0
28

39
### Bug fix

R/kernelshap.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,16 +242,19 @@ kernelshap.default <- function(
242242
p = p, deg = hybrid_degree, feature_names = feature_names
243243
)
244244
}
245-
m_exact <- nrow(precalc[["Z"]])
245+
Z <- precalc[["Z"]]
246+
m_exact <- nrow(Z)
246247
prop_exact <- sum(precalc[["w"]])
247-
precalc[["bg_X_exact"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
248+
precalc[["bg_exact_rep"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
249+
g <- rep_each(m_exact, each = bg_n)
250+
precalc[["Z_exact_rep"]] <- Z[g, , drop = FALSE]
248251
} else {
249252
precalc <- list()
250253
m_exact <- 0L
251254
prop_exact <- 0
252255
}
253256
if (!exact) {
254-
precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
257+
precalc[["bg_sampling_rep"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
255258
}
256259

257260
if (max(m, m_exact) * bg_n > 2e5) {
@@ -276,6 +279,7 @@ kernelshap.default <- function(
276279
max_iter = max_iter,
277280
v0 = v0,
278281
precalc = precalc,
282+
bg_n = bg_n,
279283
...
280284
)
281285
} else {
@@ -298,6 +302,7 @@ kernelshap.default <- function(
298302
max_iter = max_iter,
299303
v0 = v0,
300304
precalc = precalc,
305+
bg_n = bg_n,
301306
...
302307
)
303308
if (verbose && n >= 2L) {

R/permshap.R

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,27 @@ permshap.default <- function(
143143
# Pre-calculations that are identical for each row to be explained
144144
if (exact) {
145145
Z <- exact_Z(p, feature_names = feature_names)
146-
m_exact <- nrow(Z) - 2L # We won't evaluate vz for first and last row
146+
Z_no_extremes <- Z[2L:(nrow(Z) - 1L), , drop = FALSE]
147+
m_exact <- nrow(Z_no_extremes) # 2^p - 2
147148
m_eval <- 0L # for consistency with sampling case
149+
g <- rep_each(m_exact, each = bg_n)
148150
precalc <- list(
149-
Z = Z,
150-
bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
151+
Z_exact_rep = Z_no_extremes[g, , drop = FALSE],
152+
bg_exact_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
151153
positions = positions_for_exact(Z),
152154
shapley_w = shapley_weights(p, ell = rowSums(Z) - 1) # how many other players?
153155
)
154156
} else {
155157
max_iter <- as.integer(ceiling(max_iter / p) * p) # should be multiple of p
156-
m_exact <- 2L * p
157-
m <- 2L * (p - 3L) # inner loop
158+
Z <- exact_Z_balanced(p, feature_names)
159+
m_exact <- nrow(Z) # 2L * p
160+
m <- 2L * (p - 3L) # for inner loop
158161
m_eval <- if (low_memory) m else m * p # outer loop
162+
g <- rep_each(m_exact, each = bg_n)
159163
precalc <- list(
160-
bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_eval)),
161-
bg_X_balanced = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
162-
Z_balanced = exact_Z_balanced(p, feature_names)
164+
Z_balanced_rep = Z[g, , drop = FALSE],
165+
bg_balanced_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
166+
bg_sampling_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_eval))
163167
)
164168
}
165169

@@ -184,6 +188,7 @@ permshap.default <- function(
184188
low_memory = low_memory,
185189
tol = tol,
186190
max_iter = max_iter,
191+
bg_n = bg_n,
187192
...
188193
)
189194
} else {
@@ -205,6 +210,7 @@ permshap.default <- function(
205210
low_memory = low_memory,
206211
tol = tol,
207212
max_iter = max_iter,
213+
bg_n = bg_n,
208214
...
209215
)
210216
if (verbose && n >= 2L) {

R/utils.R

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,36 +73,32 @@ exact_Z <- function(p, feature_names) {
7373
#'
7474
#' @inheritParams kernelshap
7575
#' @param x Row to be explained.
76-
#' @param bg Background data stacked m times.
77-
#' @param Z A logical (m x p) matrix with on-off values.
76+
#' @param bg_rep Background data stacked m times.
77+
#' @param Z_rep A logical ((m * bg_n) x p) matrix with on-off values.
7878
#' @param w A vector with case weights (of the same length as the unstacked
7979
#' background data).
80+
#' @param bg_n Size of background dataset (unstacked).
8081
#' @returns An (m x K) matrix with vz values.
81-
get_vz <- function(x, bg, Z, object, pred_fun, w, ...) {
82-
m <- nrow(Z)
83-
n_bg <- nrow(bg) / m # because bg was replicated m times
84-
85-
# Replicate Z, so that bg and Z are of dimension (m*n_bg x p)
86-
g <- rep_each(m, each = n_bg)
87-
Z_rep <- Z[g, , drop = FALSE]
88-
89-
for (v in colnames(Z)) {
82+
get_vz <- function(x, bg_rep, Z_rep, object, pred_fun, w, bg_n, ...) {
83+
for (v in colnames(Z_rep)) {
9084
s <- Z_rep[, v]
9185
if (is.matrix(x)) {
92-
bg[s, v] <- x[, v]
86+
bg_rep[s, v] <- x[, v]
9387
} else {
94-
bg[[v]][s] <- x[[v]]
88+
bg_rep[[v]][s] <- x[[v]]
9589
}
9690
}
9791

98-
preds <- align_pred(pred_fun(object, bg, ...))
92+
preds <- align_pred(pred_fun(object, bg_rep, ...))
9993

10094
# Aggregate (distinguishing fast 1-dim case)
95+
m <- nrow(Z_rep) %/% bg_n
10196
if (ncol(preds) == 1L) {
10297
return(wrowmean_vector(preds, ngroups = m, w = w))
10398
}
99+
g <- rep_each(m, each = bg_n)
104100
if (is.null(w)) {
105-
return(rowsum(preds, group = g, reorder = FALSE) / n_bg)
101+
return(rowsum(preds, group = g, reorder = FALSE) / bg_n)
106102
}
107103
rowsum(preds * w, group = g, reorder = FALSE) / sum(w)
108104
}

R/utils_kernelshap.R

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,32 @@ kernelshap_one <- function(
1515
max_iter,
1616
v0,
1717
precalc,
18+
bg_n,
1819
...) {
1920
p <- length(feature_names)
2021
K <- ncol(v1)
2122
K_names <- colnames(v1)
2223

2324
# Calculate A_exact and b_exact
2425
if (exact || deg >= 1L) {
25-
A_exact <- precalc[["A"]] # (p x p)
26-
bg_X_exact <- precalc[["bg_X_exact"]] # (m_ex*n_bg x p)
27-
Z <- precalc[["Z"]] # (m_ex x p)
26+
A_exact <- precalc$A # (p x p)
27+
Z <- precalc$Z # (m_ex x p)
2828
m_exact <- nrow(Z)
2929
v0_m_exact <- v0[rep.int(1L, m_exact), , drop = FALSE] # (m_ex x K)
3030

3131
# Most expensive part
3232
vz <- get_vz(
33-
x = x, bg = bg_X_exact, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
33+
x = x,
34+
bg_rep = precalc$bg_exact_rep, # (m_ex*bg_n x p)
35+
Z_rep = precalc$Z_exact_rep, # (m_ex*bg_n x p)
36+
object = object,
37+
pred_fun = pred_fun,
38+
w = bg_w,
39+
bg_n = bg_n,
40+
...
3441
)
3542
# Note: w is correctly replicated along columns of (vz - v0_m_exact)
36-
b_exact <- crossprod(Z, precalc[["w"]] * (vz - v0_m_exact)) # (p x K)
43+
b_exact <- crossprod(Z, precalc$w * (vz - v0_m_exact)) # (p x K)
3744

3845
# Some of the hybrid cases are exact as well
3946
if (exact || trunc(p / 2) == deg) {
@@ -43,7 +50,8 @@ kernelshap_one <- function(
4350
}
4451

4552
# Iterative sampling part, always using A_exact and b_exact to fill up the weights
46-
bg_X_m <- precalc[["bg_X_m"]] # (m*n_bg x p)
53+
g <- rep_each(m, each = bg_n)
54+
4755
v0_m <- v0[rep.int(1L, m), , drop = FALSE] # (m x K)
4856
est_m <- array(
4957
data = 0, dim = c(max_iter, p, K), dimnames = list(NULL, feature_names, K_names)
@@ -62,16 +70,23 @@ kernelshap_one <- function(
6270
while (!converged && n_iter < max_iter) {
6371
n_iter <- n_iter + 1L
6472
input <- input_sampling(p = p, m = m, deg = deg, feature_names = feature_names)
65-
Z <- input[["Z"]]
73+
Z <- input$Z
6674

6775
# Expensive # (m x K)
6876
vz <- get_vz(
69-
x = x, bg = bg_X_m, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
77+
x = x,
78+
bg_rep = precalc$bg_sampling_rep, # (m*bg_n x p)
79+
Z_rep = Z[g, , drop = FALSE],
80+
object = object,
81+
pred_fun = pred_fun,
82+
w = bg_w,
83+
bg_n = bg_n,
84+
...
7085
)
7186

72-
# The sum of weights of A_exact and input[["A"]] is 1, same for b
73-
A_temp <- A_exact + input[["A"]] # (p x p)
74-
b_temp <- b_exact + crossprod(Z, input[["w"]] * (vz - v0_m)) # (p x K)
87+
# The sum of weights of A_exact and input$A is 1, same for b
88+
A_temp <- A_exact + input$A # (p x p)
89+
b_temp <- b_exact + crossprod(Z, input$w * (vz - v0_m)) # (p x K)
7590
A_sum <- A_sum + A_temp # (p x p)
7691
b_sum <- b_sum + b_temp # (p x K)
7792

R/utils_permshap.R

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,25 @@ permshap_one <- function(
2626
low_memory,
2727
tol,
2828
max_iter,
29+
bg_n,
2930
...) {
30-
bg <- precalc$bg_X_rep
31-
3231
p <- length(feature_names)
3332
K <- ncol(v1)
3433
K_names <- colnames(v1)
3534
beta_n <- matrix(0, nrow = p, ncol = K, dimnames = list(feature_names, K_names))
3635

3736
if (exact) {
38-
Z <- precalc$Z # ((m_ex+2) x K)
39-
vz <- get_vz( # (m_ex x K)
37+
vz <- get_vz(
4038
x = x,
41-
bg = bg,
42-
Z = Z[2L:(nrow(Z) - 1L), , drop = FALSE], # (m_ex x p)
39+
bg_rep = precalc$bg_exact_rep,
40+
Z_rep = precalc$Z_exact_rep,
4341
object = object,
4442
pred_fun = pred_fun,
4543
w = bg_w,
44+
bg_n = bg_n,
4645
...
4746
)
48-
vz <- rbind(v0, vz, v1) # we add the cheaply calculated v0 and v1
47+
vz <- rbind(v0, vz, v1)
4948

5049
for (j in seq_len(p)) {
5150
pos <- precalc$positions[[j]]
@@ -68,11 +67,12 @@ permshap_one <- function(
6867
# Pre-calculate part of Z with rowsum 1 or p - 1
6968
vz_balanced <- get_vz( # (2p x K)
7069
x = x,
71-
bg = precalc$bg_X_balanced,
72-
Z = precalc$Z_balanced,
70+
bg_rep = precalc$bg_balanced_rep,
71+
Z_rep = precalc$Z_balanced_rep,
7372
object = object,
7473
pred_fun = pred_fun,
7574
w = bg_w,
75+
bg_n = bg_n,
7676
...
7777
)
7878

@@ -83,24 +83,35 @@ permshap_one <- function(
8383
from_balanced <- c(2L, 2L + p, p, 2L * p)
8484
from_iter <- c(3L:(p - 1L), (p + 3L):(2L * p - 1L))
8585

86+
bg_sampling_rep <- precalc$bg_sampling_rep
87+
g <- rep_each(nrow(bg_sampling_rep) %/% bg_n, each = bg_n)
88+
8689
while (!converged && n_iter < max_iter) {
8790
chains <- balanced_chains(p)
8891
Z <- lapply(chains, sample_Z_from_chain, feature_names = feature_names)
8992
if (!low_memory) { # predictions for all chains at once
9093
Z <- do.call(rbind, Z)
9194
vz <- get_vz(
92-
x = x, bg = bg, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
95+
x = x,
96+
bg_rep = bg_sampling_rep,
97+
Z_rep = Z[g, , drop = FALSE],
98+
object = object,
99+
pred_fun = pred_fun,
100+
w = bg_w,
101+
bg_n = bg_n,
102+
...
93103
)
94104
} else { # predictions for each chain separately
95105
vz <- vector("list", length = p)
96106
for (j in seq_len(p)) {
97107
vz[[j]] <- get_vz(
98108
x = x,
99-
bg = bg,
100-
Z = Z[[j]],
109+
bg_rep = bg_sampling_rep,
110+
Z_rep = Z[[j]][g, , drop = FALSE],
101111
object = object,
102112
pred_fun = pred_fun,
103113
w = bg_w,
114+
bg_n = bg_n,
104115
...
105116
)
106117
}

backlog/compare_with_python.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,23 @@ ks
2525
# [1,] -2.050074 -0.28048747 0.1281222 0.01587382
2626
# [2,] -2.085838 0.04050415 0.1283010 0.03731644
2727

28-
# Pure sampling version takes a bit longer (6.6 seconds)
28+
# Pure sampling version takes a bit longer (5.6 seconds)
2929
system.time(
3030
ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
3131
)
3232
ks2
3333

3434
bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F))
35-
# 2.17s 1.64GB -> 1.79s 1.43GB
35+
# 1.66s 1.4GB -> 1.79s 1.43GB
3636

3737
bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F, hybrid_degree = 1))
38-
# 4.88s 2.79GB -> 4.38s 2.48GB
38+
# 4.58s 2.45GB -> 4.38s 2.48GB
3939

4040
bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F))
41-
# 1.97s 1.64GB -> 1.9s 1.43GB
41+
# 1.75s 1.4GB -> 1.9s 1.43GB
4242

4343
bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F))
44-
# 3.04s 1.88GB -> 2.8s 1.64GB
44+
# 3.97s 1.63GB -> 2.8s 1.64GB
4545

4646
# SHAP values of first 2 observations:
4747
# carat clarity color cut

backlog/compare_with_python2.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ pf <- function(model, newdata) {
1818
}
1919
ks <- kernelshap(pf, head(X), bg_X = X, pred_fun = pf)
2020
ks # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252
21-
es <- permshap(pf, head(X), bg_X = X, pred_fun = pf)
22-
es # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252
21+
ps <- permshap(pf, head(X), bg_X = X, pred_fun = pf)
22+
ps # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252
2323

2424
set.seed(10)
2525
kss <- kernelshap(
@@ -61,3 +61,15 @@ ksh2 <- kernelshap(
6161
tol = 0.0001
6262
)
6363
ksh2 # 1.195976 -1.241107 -0.9565121 3.878891 -0.3384621 0.5451118
64+
65+
set.seed(1)
66+
pss <- permshap(
67+
pf,
68+
head(X, 1),
69+
bg_X = X,
70+
pred_fun = pf,
71+
exact = FALSE,
72+
max_iter = 40000,
73+
tol = 0.0001
74+
)
75+
pss # -1.222608 -1.252001 -0.9312635 3.890444 -0.33825 0.5456252 non-convergence

packaging.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ library(usethis)
1515
use_description(
1616
fields = list(
1717
Title = "Kernel SHAP",
18-
Version = "0.9.0",
18+
Version = "0.9.1",
1919
Description = "Efficient implementation of Kernel SHAP
2020
(Lundberg and Lee, 2017, <doi:10.48550/arXiv.1705.07874>)
2121
permutation SHAP, and additive SHAP for model interpretability.

0 commit comments

Comments
 (0)