Skip to content

Commit db38b00

Browse files
committed
Support factor-valued predictions
1 parent 2a6fd0c commit db38b00

File tree

10 files changed

+894
-12
lines changed

10 files changed

+894
-12
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: kernelshap
22
Title: Kernel SHAP
3-
Version: 0.3.8
3+
Version: 0.4.0
44
Authors@R: c(
55
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre")),
66
person("David", "Watson", , "david.s.watson11@gmail.com", role = "aut"),

NEWS.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# kernelshap 0.4.0
2+
3+
## Major changes
4+
5+
- Factor valued predictions are now supported. Each level is represented by its dummy variable.
6+
7+
## Other changes
8+
9+
- Slight speed-up.
10+
- Integer valued case-weights are now turned into doubles to avoid integer overflow.
11+
112
# kernelshap 0.3.8
213

314
## API improvements

R/from_hstats.R

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# These functions have originally been implemented in {hstats}
2+
3+
#' Fast Index Generation
4+
#'
5+
#' For not too small m, much faster than `rep(seq_len(m), each = each)`.
6+
#'
7+
#' @noRd
8+
#' @keywords internal
9+
#'
10+
#' @param m Integer. See `each`.
11+
#' @param each Integer. How many times should each value in `1:m` be repeated?
12+
#' @returns Like `x`, but converted to matrix.
13+
#' @examples
14+
#' rep_each(10, 2)
15+
#' rep(1:10, each = 2) # Dito
16+
rep_each <- function(m, each) {
17+
out <- .col(dim = c(each, m))
18+
dim(out) <- NULL
19+
out
20+
}
21+
22+
#' Fast OHE
23+
#'
24+
#' Turns vector/factor into its One-Hot-Encoding.
25+
#' Ingeniouly written by Mathias Ambuehl.
26+
#'
27+
#' Working with integers instead of doubles would be slightly faster, but at the price
28+
#' of potential integer overflows in subsequent calculations.
29+
#'
30+
#' @noRd
31+
#' @keywords internal
32+
#'
33+
#' @param x Object representing model predictions.
34+
#' @returns Like `x`, but converted to matrix.
35+
fdummy <- function(x) {
36+
x <- as.factor(x)
37+
lev <- levels(x)
38+
out <- matrix(0, nrow = length(x), ncol = length(lev))
39+
out[cbind(seq_along(x), as.integer(x))] <- 1
40+
colnames(out) <- lev
41+
out
42+
}

R/kernelshap.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
213213
bg_n <- nrow(bg_X)
214214
if (!is.null(bg_w)) {
215215
stopifnot(length(bg_w) == bg_n, all(bg_w >= 0), !all(bg_w == 0))
216+
if (!is.double(bg_w)) {
217+
bg_w <- as.double(bg_w)
218+
}
216219
}
217220
if (is.matrix(X) && !identical(colnames(X), feature_names)) {
218221
stop("If X is a matrix, feature_names must equal colnames(X)")

R/utils.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
141141
n_bg <- nrow(bg) / m # because bg was replicated m times
142142

143143
# Replicate not_Z, so that X, bg, not_Z are all of dimension (m*n_bg x p)
144-
g <- rep(seq_len(m), each = n_bg)
144+
g <- rep_each(m, each = n_bg) # from_hstats.R
145145
not_Z <- not_Z[g, , drop = FALSE]
146146

147147
if (is.matrix(X)) {
@@ -175,7 +175,7 @@ get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
175175
#' @returns A (1 x ncol(x)) matrix of column means.
176176
weighted_colMeans <- function(x, w = NULL, ...) {
177177
if (NCOL(x) == 1L && is.null(w)) {
178-
return(matrix(mean(x)))
178+
return(as.matrix(mean(x)))
179179
}
180180
if (!is.matrix(x)) {
181181
x <- as.matrix(x)
@@ -226,21 +226,21 @@ reorganize_list <- function(alist) {
226226

227227
#' Aligns Predictions
228228
#'
229-
#' Turns predictions into matrix. Originally implemented in {hstats}.
229+
#' Turns predictions into matrix.
230230
#'
231231
#' @noRd
232232
#' @keywords internal
233233
#'
234234
#' @param x Object representing model predictions.
235235
#' @returns Like `x`, but converted to matrix.
236236
align_pred <- function(x) {
237-
if (!is.matrix(x)) {
238-
x <- as.matrix(x)
237+
if (is.data.frame(x) && ncol(x) == 1L) {
238+
x <- x[[1L]]
239239
}
240-
if (!is.numeric(x)) {
241-
stop("Predictions must be numeric")
240+
if (is.factor(x)) {
241+
return(fdummy(x)) # from_hstats.R
242242
}
243-
x
243+
if (is.matrix(x)) x else as.matrix(x)
244244
}
245245

246246
#' Head of List Elements

README.md

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ Let's model diamonds prices!
4646
### Linear regression
4747

4848
```r
49-
library(ggplot2)
5049
library(kernelshap)
50+
library(ggplot2)
5151
library(shapviz)
5252

5353
diamonds <- transform(
@@ -221,6 +221,36 @@ shap_gam
221221
# [2,] -0.5153642 -0.1080045 0.11967804 0.031341595
222222
```
223223

224+
## Multi-output models
225+
226+
{kernelshap} supports multivariate predictions, such as:
227+
- probabilistic classification,
228+
- non-probabilistic classification (factor-valued responses are turned into dummies),
229+
- regression with multivariate response, and
230+
- predictions found by applying multiple regression models.
231+
232+
### Classification
233+
234+
We use {ranger} to fit a probabilistic and a non-probabilistic classification model.
235+
236+
```r
237+
library(kernelshap)
238+
library(ranger)
239+
library(shapviz)
240+
241+
# Probabilistic
242+
fit_prob <- ranger(Species ~ ., data = iris, num.trees = 20, probability = TRUE, seed = 1)
243+
ks_prob <- kernelshap(fit_prob, X = iris, bg_X = iris) |>
244+
shapviz()
245+
sv_importance(ks_prob)
246+
247+
# Non-probabilistic: Predictions are factors
248+
fit_class <- ranger(Species ~ ., data = iris, num.trees = 20, seed = 1)
249+
ks_class <- kernelshap(fit_class, X = iris, bg_X = iris) |>
250+
shapviz()
251+
sv_importance(ks_class)
252+
```
253+
224254
## Meta-learning packages
225255

226256
Here, we provide some working examples for "tidymodels", "caret", and "mlr3".
@@ -283,7 +313,7 @@ fit_lm$train(task_iris)
283313
s <- kernelshap(fit_lm, iris[-1], bg_X = iris)
284314
s
285315

286-
# Probabilistic classification -> lrn(..., predict_type = "prob")
316+
# *Probabilistic* classification -> lrn(..., predict_type = "prob")
287317
task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
288318
fit_rf <- lrn("classif.ranger", predict_type = "prob", num.trees = 50)
289319
fit_rf$train(task_iris)

0 commit comments

Comments
 (0)