Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: kernelshap
Title: Kernel SHAP
Version: 0.9.0
Version: 0.9.1
Authors@R: c(
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre"),
comment = c(ORCID = "0009-0007-2540-9629")),
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# kernelshap 0.9.1

### Speed and memory improvements

- More pre-calculations for the exact part of the methods ([#175](https://github.com/ModelOriented/kernelshap/pull/175)).

# kernelshap 0.9.0

### Bug fix
Expand Down
11 changes: 8 additions & 3 deletions R/kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,19 @@ kernelshap.default <- function(
p = p, deg = hybrid_degree, feature_names = feature_names
)
}
m_exact <- nrow(precalc[["Z"]])
Z <- precalc[["Z"]]
m_exact <- nrow(Z)
prop_exact <- sum(precalc[["w"]])
precalc[["bg_X_exact"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
precalc[["bg_exact_rep"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
g <- rep_each(m_exact, each = bg_n)
precalc[["Z_exact_rep"]] <- Z[g, , drop = FALSE]
} else {
precalc <- list()
m_exact <- 0L
prop_exact <- 0
}
if (!exact) {
precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
precalc[["bg_sampling_rep"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
}

if (max(m, m_exact) * bg_n > 2e5) {
Expand All @@ -276,6 +279,7 @@ kernelshap.default <- function(
max_iter = max_iter,
v0 = v0,
precalc = precalc,
bg_n = bg_n,
...
)
} else {
Expand All @@ -298,6 +302,7 @@ kernelshap.default <- function(
max_iter = max_iter,
v0 = v0,
precalc = precalc,
bg_n = bg_n,
...
)
if (verbose && n >= 2L) {
Expand Down
22 changes: 14 additions & 8 deletions R/permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,23 +143,27 @@ permshap.default <- function(
# Pre-calculations that are identical for each row to be explained
if (exact) {
Z <- exact_Z(p, feature_names = feature_names)
m_exact <- nrow(Z) - 2L # We won't evaluate vz for first and last row
Z_no_extremes <- Z[2L:(nrow(Z) - 1L), , drop = FALSE]
m_exact <- nrow(Z_no_extremes) # 2^p - 2
m_eval <- 0L # for consistency with sampling case
g <- rep_each(m_exact, each = bg_n)
precalc <- list(
Z = Z,
bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
Z_exact_rep = Z_no_extremes[g, , drop = FALSE],
bg_exact_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
positions = positions_for_exact(Z),
shapley_w = shapley_weights(p, ell = rowSums(Z) - 1) # how many other players?
)
} else {
max_iter <- as.integer(ceiling(max_iter / p) * p) # should be multiple of p
m_exact <- 2L * p
m <- 2L * (p - 3L) # inner loop
Z <- exact_Z_balanced(p, feature_names)
m_exact <- nrow(Z) # 2L * p
m <- 2L * (p - 3L) # for inner loop
m_eval <- if (low_memory) m else m * p # outer loop
g <- rep_each(m_exact, each = bg_n)
precalc <- list(
bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_eval)),
bg_X_balanced = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
Z_balanced = exact_Z_balanced(p, feature_names)
Z_balanced_rep = Z[g, , drop = FALSE],
bg_balanced_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)),
bg_sampling_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_eval))
)
}

Expand All @@ -184,6 +188,7 @@ permshap.default <- function(
low_memory = low_memory,
tol = tol,
max_iter = max_iter,
bg_n = bg_n,
...
)
} else {
Expand All @@ -205,6 +210,7 @@ permshap.default <- function(
low_memory = low_memory,
tol = tol,
max_iter = max_iter,
bg_n = bg_n,
...
)
if (verbose && n >= 2L) {
Expand Down
26 changes: 11 additions & 15 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,36 +73,32 @@ exact_Z <- function(p, feature_names) {
#'
#' @inheritParams kernelshap
#' @param x Row to be explained.
#' @param bg Background data stacked m times.
#' @param Z A logical (m x p) matrix with on-off values.
#' @param bg_rep Background data stacked m times.
#' @param Z_rep A logical ((m * bg_n) x p) matrix with on-off values.
#' @param w A vector with case weights (of the same length as the unstacked
#' background data).
#' @param bg_n Size of background dataset (unstacked).
#' @returns A (m x K) matrix with vz values.
get_vz <- function(x, bg, Z, object, pred_fun, w, ...) {
m <- nrow(Z)
n_bg <- nrow(bg) / m # because bg was replicated m times

# Replicate Z, so that bg and Z are of dimension (m*n_bg x p)
g <- rep_each(m, each = n_bg)
Z_rep <- Z[g, , drop = FALSE]

for (v in colnames(Z)) {
get_vz <- function(x, bg_rep, Z_rep, object, pred_fun, w, bg_n, ...) {
for (v in colnames(Z_rep)) {
s <- Z_rep[, v]
if (is.matrix(x)) {
bg[s, v] <- x[, v]
bg_rep[s, v] <- x[, v]
} else {
bg[[v]][s] <- x[[v]]
bg_rep[[v]][s] <- x[[v]]
}
}

preds <- align_pred(pred_fun(object, bg, ...))
preds <- align_pred(pred_fun(object, bg_rep, ...))

# Aggregate (distinguishing fast 1-dim case)
m <- nrow(Z_rep) %/% bg_n
if (ncol(preds) == 1L) {
return(wrowmean_vector(preds, ngroups = m, w = w))
}
g <- rep_each(m, each = bg_n)
if (is.null(w)) {
return(rowsum(preds, group = g, reorder = FALSE) / n_bg)
return(rowsum(preds, group = g, reorder = FALSE) / bg_n)
}
rowsum(preds * w, group = g, reorder = FALSE) / sum(w)
}
Expand Down
37 changes: 26 additions & 11 deletions R/utils_kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,32 @@ kernelshap_one <- function(
max_iter,
v0,
precalc,
bg_n,
...) {
p <- length(feature_names)
K <- ncol(v1)
K_names <- colnames(v1)

# Calculate A_exact and b_exact
if (exact || deg >= 1L) {
A_exact <- precalc[["A"]] # (p x p)
bg_X_exact <- precalc[["bg_X_exact"]] # (m_ex*n_bg x p)
Z <- precalc[["Z"]] # (m_ex x p)
A_exact <- precalc$A # (p x p)
Z <- precalc$Z # (m_ex x p)
m_exact <- nrow(Z)
v0_m_exact <- v0[rep.int(1L, m_exact), , drop = FALSE] # (m_ex x K)

# Most expensive part
vz <- get_vz(
x = x, bg = bg_X_exact, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
x = x,
bg_rep = precalc$bg_exact_rep, # (m_ex*bg_n x p)
Z_rep = precalc$Z_exact_rep, # (m_ex*bg_n x p)
object = object,
pred_fun = pred_fun,
w = bg_w,
bg_n = bg_n,
...
)
# Note: w is correctly replicated along columns of (vz - v0_m_exact)
b_exact <- crossprod(Z, precalc[["w"]] * (vz - v0_m_exact)) # (p x K)
b_exact <- crossprod(Z, precalc$w * (vz - v0_m_exact)) # (p x K)

# Some of the hybrid cases are exact as well
if (exact || trunc(p / 2) == deg) {
Expand All @@ -43,7 +50,8 @@ kernelshap_one <- function(
}

# Iterative sampling part, always using A_exact and b_exact to fill up the weights
bg_X_m <- precalc[["bg_X_m"]] # (m*n_bg x p)
g <- rep_each(m, each = bg_n)

v0_m <- v0[rep.int(1L, m), , drop = FALSE] # (m x K)
est_m <- array(
data = 0, dim = c(max_iter, p, K), dimnames = list(NULL, feature_names, K_names)
Expand All @@ -62,16 +70,23 @@ kernelshap_one <- function(
while (!converged && n_iter < max_iter) {
n_iter <- n_iter + 1L
input <- input_sampling(p = p, m = m, deg = deg, feature_names = feature_names)
Z <- input[["Z"]]
Z <- input$Z

# Expensive # (m x K)
vz <- get_vz(
x = x, bg = bg_X_m, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
x = x,
bg_rep = precalc$bg_sampling_rep, # (m*bg_n x p)
Z_rep = Z[g, , drop = FALSE],
object = object,
pred_fun = pred_fun,
w = bg_w,
bg_n = bg_n,
...
)

# The sum of weights of A_exact and input[["A"]] is 1, same for b
A_temp <- A_exact + input[["A"]] # (p x p)
b_temp <- b_exact + crossprod(Z, input[["w"]] * (vz - v0_m)) # (p x K)
# The sum of weights of A_exact and input$A is 1, same for b
A_temp <- A_exact + input$A # (p x p)
b_temp <- b_exact + crossprod(Z, input$w * (vz - v0_m)) # (p x K)
A_sum <- A_sum + A_temp # (p x p)
b_sum <- b_sum + b_temp # (p x K)

Expand Down
35 changes: 23 additions & 12 deletions R/utils_permshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,25 @@ permshap_one <- function(
low_memory,
tol,
max_iter,
bg_n,
...) {
bg <- precalc$bg_X_rep

p <- length(feature_names)
K <- ncol(v1)
K_names <- colnames(v1)
beta_n <- matrix(0, nrow = p, ncol = K, dimnames = list(feature_names, K_names))

if (exact) {
Z <- precalc$Z # ((m_ex+2) x K)
vz <- get_vz( # (m_ex x K)
vz <- get_vz(
x = x,
bg = bg,
Z = Z[2L:(nrow(Z) - 1L), , drop = FALSE], # (m_ex x p)
bg_rep = precalc$bg_exact_rep,
Z_rep = precalc$Z_exact_rep,
object = object,
pred_fun = pred_fun,
w = bg_w,
bg_n = bg_n,
...
)
vz <- rbind(v0, vz, v1) # we add the cheaply calculated v0 and v1
vz <- rbind(v0, vz, v1)

for (j in seq_len(p)) {
pos <- precalc$positions[[j]]
Expand All @@ -68,11 +67,12 @@ permshap_one <- function(
# Pre-calculate part of Z with rowsum 1 or p - 1
vz_balanced <- get_vz( # (2p x K)
x = x,
bg = precalc$bg_X_balanced,
Z = precalc$Z_balanced,
bg_rep = precalc$bg_balanced_rep,
Z_rep = precalc$Z_balanced_rep,
object = object,
pred_fun = pred_fun,
w = bg_w,
bg_n = bg_n,
...
)

Expand All @@ -83,24 +83,35 @@ permshap_one <- function(
from_balanced <- c(2L, 2L + p, p, 2L * p)
from_iter <- c(3L:(p - 1L), (p + 3L):(2L * p - 1L))

bg_sampling_rep <- precalc$bg_sampling_rep
g <- rep_each(nrow(bg_sampling_rep) %/% bg_n, each = bg_n)

while (!converged && n_iter < max_iter) {
chains <- balanced_chains(p)
Z <- lapply(chains, sample_Z_from_chain, feature_names = feature_names)
if (!low_memory) { # predictions for all chains at once
Z <- do.call(rbind, Z)
vz <- get_vz(
x = x, bg = bg, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ...
x = x,
bg_rep = bg_sampling_rep,
Z_rep = Z[g, , drop = FALSE],
object = object,
pred_fun = pred_fun,
w = bg_w,
bg_n = bg_n,
...
)
} else { # predictions for each chain separately
vz <- vector("list", length = p)
for (j in seq_len(p)) {
vz[[j]] <- get_vz(
x = x,
bg = bg,
Z = Z[[j]],
bg_rep = bg_sampling_rep,
Z_rep = Z[[j]][g, , drop = FALSE],
object = object,
pred_fun = pred_fun,
w = bg_w,
bg_n = bg_n,
...
)
}
Expand Down
10 changes: 5 additions & 5 deletions backlog/compare_with_python.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@ ks
# [1,] -2.050074 -0.28048747 0.1281222 0.01587382
# [2,] -2.085838 0.04050415 0.1283010 0.03731644

# Pure sampling version takes a bit longer (6.6 seconds)
# Pure sampling version takes a bit longer (5.6 seconds)
system.time(
ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
)
ks2

bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F))
# 2.17s 1.64GB -> 1.79s 1.43GB
# 1.66s 1.4GB -> 1.79s 1.43GB

bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F, hybrid_degree = 1))
# 4.88s 2.79GB -> 4.38s 2.48GB
# 4.58s 2.45GB -> 4.38s 2.48GB

bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F))
# 1.97s 1.64GB -> 1.9s 1.43GB
# 1.75s 1.4GB -> 1.9s 1.43GB

bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F))
# 3.04s 1.88GB -> 2.8s 1.64GB
# 3.97s 1.63GB -> 2.8s 1.64GB

# SHAP values of first 2 observations:
# carat clarity color cut
Expand Down
16 changes: 14 additions & 2 deletions backlog/compare_with_python2.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ pf <- function(model, newdata) {
}
ks <- kernelshap(pf, head(X), bg_X = X, pred_fun = pf)
ks # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252
es <- permshap(pf, head(X), bg_X = X, pred_fun = pf)
es # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252
ps <- permshap(pf, head(X), bg_X = X, pred_fun = pf)
ps # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252

set.seed(10)
kss <- kernelshap(
Expand Down Expand Up @@ -61,3 +61,15 @@ ksh2 <- kernelshap(
tol = 0.0001
)
ksh2 # 1.195976 -1.241107 -0.9565121 3.878891 -0.3384621 0.5451118

set.seed(1)
pss <- permshap(
pf,
head(X, 1),
bg_X = X,
pred_fun = pf,
exact = FALSE,
max_iter = 40000,
tol = 0.0001
)
pss # -1.222608 -1.252001 -0.9312635 3.890444 -0.33825 0.5456252 non-convergence
2 changes: 1 addition & 1 deletion packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "Kernel SHAP",
Version = "0.9.0",
Version = "0.9.1",
Description = "Efficient implementation of Kernel SHAP
(Lundberg and Lee, 2017, <doi:10.48550/arXiv.1705.07874>)
permutation SHAP, and additive SHAP for model interpretability.
Expand Down
Loading
Loading