Skip to content

Commit 56a8ad1

Browse files
authored
Merge pull request #123 from ModelOriented/performance
Performance improvements
2 parents 0dc629c + 81265db commit 56a8ad1

File tree

9 files changed

+67
-21
lines changed

9 files changed

+67
-21
lines changed

NEWS.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# kernelshap 0.4.1
22

3-
## Other changes
3+
## Performance improvements
44

5-
- Slight speed-up of `kernelshap()` and `permshap()` for single-output predictions.
6-
- Slight speed-up of `kernelshap()` and `permshap()` for factor-valued predictions.
5+
- Significant speed-up for plain data.frames, i.e., objects whose class is exactly `"data.frame"` (no data.tables, tibbles, grouped data, etc.). With this change, working with data.frames is almost as fast as working with matrices.
6+
- Slight speed-up for single-output predictions.
7+
- Slight speed-up for factor-valued predictions.
78
- Slight speed-up of `permshap()` by caching calculations for the two special permutations of all 0 and all 1. Consequently, the `m_exact` component in the output is reduced by 2.
89

910
# kernelshap 0.4.0

R/kernelshap.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,14 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
233233
}
234234
m_exact <- nrow(precalc[["Z"]])
235235
prop_exact <- sum(precalc[["w"]])
236-
precalc[["bg_X_exact"]] <- bg_X[rep(seq_len(bg_n), times = m_exact), , drop = FALSE]
236+
precalc[["bg_X_exact"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
237237
} else {
238238
precalc <- list()
239239
m_exact <- 0L
240240
prop_exact <- 0
241241
}
242242
if (!exact) {
243-
precalc[["bg_X_m"]] <- bg_X[rep(seq_len(bg_n), times = m), , drop = FALSE]
243+
precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m))
244244
}
245245

246246
# Some infos

R/permshap.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
9090
precalc <- list(
9191
Z = Z,
9292
Z_code = rowpaste(Z),
93-
bg_X_rep = bg_X[rep(seq_len(bg_n), times = m_exact), , drop = FALSE]
93+
bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
9494
)
9595

9696
if (m_exact * bg_n > 2e5) {

R/utils.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,26 @@
1+
#' Fast Row Subsetting
#'
#' Internal function used to row-subset data.frames.
#' Brings a massive speed-up for data.frames, i.e., objects whose class is exactly
#' "data.frame": they are subsetted column by column instead of via `[`.
#' All other classes (tibble, data.table, matrix) are subsetted in the usual way.
#'
#' @noRd
#' @keywords internal
#'
#' @param x A matrix-like object.
#' @param i Logical or integer vector of rows to pick.
#' @returns Subsetted version of `x`. For plain data.frames, the row names of the
#'   result are reset to `1, ..., n`, where `n` is the number of selected rows.
rep_rows <- function(x, i) {
  if (!identical(class(x), "data.frame")) {
    return(x[i, , drop = FALSE])  # matrix, tibble, data.table, ...
  }

  # data.frame
  # Translate logical, negative, or zero indices into positive row positions.
  # Otherwise, length(i) would not match the number of selected rows, and the
  # row names set below would be inconsistent with the column lengths.
  i <- seq_len(nrow(x))[i]

  # Columns can themselves be two-dimensional (matrix or data.frame columns)
  pick_rows <- function(z) {
    if (length(dim(z)) != 2L) z[i] else z[i, , drop = FALSE]
  }
  out <- lapply(x, pick_rows)
  attr(out, "row.names") <- .set_row_names(length(i))
  class(out) <- "data.frame"
  out
}
23+
124
#' Weighted Version of colMeans()
225
#'
326
#' Internal function used to calculate column-wise weighted means.

R/utils_kernelshap.R

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
1111
bg_X_exact <- precalc[["bg_X_exact"]] # (m_ex*n_bg x p)
1212
Z <- precalc[["Z"]] # (m_ex x p)
1313
m_exact <- nrow(Z)
14-
v0_m_exact <- v0[rep(1L, m_exact), , drop = FALSE] # (m_ex x K)
15-
14+
v0_m_exact <- v0[rep.int(1L, m_exact), , drop = FALSE] # (m_ex x K)
15+
1616
# Most expensive part
1717
vz <- get_vz( # (m_ex x K)
18-
X = x[rep.int(1L, times = nrow(bg_X_exact)), , drop = FALSE],# (m_ex*n_bg x p)
18+
X = rep_rows(x, rep.int(1L, nrow(bg_X_exact))), # (m_ex*n_bg x p)
1919
bg = bg_X_exact, # (m_ex*n_bg x p)
2020
Z = Z, # (m_ex x p)
2121
object = object,
@@ -35,8 +35,8 @@ kernelshap_one <- function(x, v1, object, pred_fun, feature_names, bg_w, exact,
3535

3636
# Iterative sampling part, always using A_exact and b_exact to fill up the weights
3737
bg_X_m <- precalc[["bg_X_m"]] # (m*n_bg x p)
38-
X <- x[rep(1L, times = nrow(bg_X_m)), , drop = FALSE] # (m*n_bg x p)
39-
v0_m <- v0[rep(1L, m), , drop = FALSE] # (m x K)
38+
X <- rep_rows(x, rep.int(1L, nrow(bg_X_m))) # (m*n_bg x p)
39+
v0_m <- v0[rep.int(1L, m), , drop = FALSE] # (m x K)
4040

4141
est_m = list()
4242
converged <- FALSE
@@ -91,7 +91,7 @@ solver <- function(A, b, constraint) {
9191
Ainv <- ginv(A)
9292
dimnames(Ainv) <- dimnames(A)
9393
s <- (matrix(colSums(Ainv %*% b), nrow = 1L) - constraint) / sum(Ainv) # (1 x K)
94-
Ainv %*% (b - s[rep(1L, p), , drop = FALSE]) # (p x K)
94+
Ainv %*% (b - s[rep.int(1L, p), , drop = FALSE]) # (p x K)
9595
}
9696

9797
ginv <- function (X, tol = sqrt(.Machine$double.eps)) {
@@ -138,7 +138,7 @@ sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
138138
# t(out)
139139

140140
# Vectorized by Mathias Ambuehl
141-
out <- rep(rep(0:1, m), as.vector(rbind(p - N, N)))
141+
out <- rep(rep.int(0:1, m), as.vector(rbind(p - N, N)))
142142
dim(out) <- c(p, m)
143143
ord <- order(col(out), sample.int(m * p))
144144
out[] <- out[ord]
@@ -180,7 +180,7 @@ input_sampling <- function(p, m, deg, paired, feature_names) {
180180
}
181181
w_total <- if (deg == 0L) 1 else 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
182182
w <- w_total / m
183-
list(Z = Z, w = rep(w, m), A = crossprod(Z) * w)
183+
list(Z = Z, w = rep.int(w, m), A = crossprod(Z) * w)
184184
}
185185

186186
# Functions required only for handling (partly) exact cases
@@ -262,7 +262,7 @@ input_partly_exact <- function(p, deg, feature_names) {
262262
Z[[k]] <- partly_exact_Z(p, k = k, feature_names = feature_names)
263263
n <- nrow(Z[[k]])
264264
w_tot <- kw[k] * (2 - (p == 2L * k))
265-
w[[k]] <- rep(w_tot / n, n)
265+
w[[k]] <- rep.int(w_tot / n, n)
266266
}
267267
w <- unlist(w, recursive = FALSE, use.names = FALSE)
268268
Z <- do.call(rbind, Z)

R/utils_permshap.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ shapley_weights <- function(p, ell) {
2828
permshap_one <- function(x, v1, object, pred_fun, bg_w, v0, precalc, ...) {
2929
Z <- precalc[["Z"]] # ((m_ex+2) x K)
3030
vz <- get_vz( # (m_ex x K)
31-
X = x[rep.int(1L, times = nrow(precalc[["bg_X_rep"]])), , drop = FALSE], # (m_ex*n_bg x p)
31+
X = rep_rows(x, rep.int(1L, times = nrow(precalc[["bg_X_rep"]]))), # (m_ex*n_bg x p)
3232
bg = precalc[["bg_X_rep"]], # (m_ex*n_bg x p)
3333
Z = Z[2:(nrow(Z) - 1L), , drop = FALSE], # (m_ex x p)
3434
object = object,

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ sv_dependence(sv_nn, "clarity")
206206

207207
## Parallel computing
208208

209-
Parallel computing is supported via `foreach`, at the price of losing the progress bar. Note that this does not work with Keras models (and some others).
209+
Parallel computing is supported via {foreach}, at the price of losing the progress bar. Note that this does not work with Keras models (and some others).
210210

211211
### Example: Linear regression continued
212212

@@ -226,7 +226,7 @@ system.time(
226226

227227
### Example: Parallel GAM
228228

229-
On Windows, sometimes not all packages or global objects are passed to the parallel sessions. In this case, the necessary instructions to `foreach` can be specified through a named list via `parallel_args`, see the following example:
229+
On Windows, sometimes not all packages or global objects are passed to the parallel sessions. In this case, the necessary instructions to {foreach} can be specified through a named list via `parallel_args`, see the following example:
230230

231231
```r
232232
library(mgcv)

backlog/2023-11-11 Permutation-SHAP.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@ library(ranger)
33

44
differences <- numeric(4)
55

6+
system.time({
67
set.seed(1)
7-
88
for (depth in 1:4) {
9+
print(depth)
910
fit <- ranger(
1011
Sepal.Length ~ .,
1112
mtry = 3,
1213
data = iris,
1314
max.depth = depth
1415
)
15-
ps <- permshap(fit, iris[2:5], bg_X = iris)
16-
ks <- kernelshap(fit, iris[2:5], bg_X = iris)
16+
ps <- permshap(fit, iris[2:5], bg_X = iris, verbose = FALSE)
17+
ks <- kernelshap(fit, iris[2:5], bg_X = iris, verbose = FALSE)
1718
differences[depth] <- mean(abs(ks$S - ps$S))
1819
}
20+
})
1921

2022
differences # for tree depth 1, 2, 3, 4
2123
# 5.053249e-17 9.046443e-17 2.387905e-04 4.403375e-04

tests/testthat/test-utils.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,26 @@ test_that("exact_Z() works for both kernel- and permshap", {
3232
expect_equal(dim(r1), c(2^p, p))
3333
})
3434

35+
test_that("rep_rows() gives the same as usual subsetting (except rownames)", {
  # For data.frames, rep_rows() returns row names 1..n, while `[` keeps/derives
  # them from the original rows. Reset them before comparing.
  setrn <- function(x) {
    rownames(x) <- seq_len(nrow(x))
    x
  }

  expect_equal(rep_rows(iris, 1), iris[1, ])
  expect_equal(rep_rows(iris, 2:1), setrn(iris[2:1, ]))
  expect_equal(rep_rows(iris, c(1, 1, 1)), setrn(iris[c(1, 1, 1), ]))

  # Also works with list columns
  ir <- iris[1, ]
  ir$y <- list(list(a = 1, b = 2))
  expect_equal(rep_rows(ir, c(1, 1)), setrn(ir[c(1, 1), ]))
})
46+
47+
test_that("rep_rows() gives the same as usual subsetting for matrices", {
  # Numeric matrix built from the first four iris columns
  mat <- data.matrix(iris[1:4])
  idx <- c(1, 1, 2)

  # Repeated rows
  expect_equal(rep_rows(mat, idx), mat[idx, ])

  # A single row stays a matrix
  expect_equal(rep_rows(mat, 1), mat[1, , drop = FALSE])
})
53+
54+
3555
# Unit tests copied from {hstats}
3656

3757
test_that("rep_each() works", {

0 commit comments

Comments
 (0)