Skip to content

Commit 7e8202f

Browse files
authored
Merge pull request #107 from ModelOriented/factors
Support factor-valued predictions
2 parents 2a6fd0c + fd2e393 commit 7e8202f

File tree

11 files changed

+907
-12
lines changed

11 files changed

+907
-12
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: kernelshap
22
Title: Kernel SHAP
3-
Version: 0.3.8
3+
Version: 0.4.0
44
Authors@R: c(
55
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre")),
66
person("David", "Watson", , "david.s.watson11@gmail.com", role = "aut"),

NEWS.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# kernelshap 0.4.0
2+
3+
## Major changes
4+
5+
- Factor valued predictions are now supported. Each level is represented by its dummy variable.
6+
7+
## Other changes
8+
9+
- Slight speed-up.
10+
- Integer valued case-weights are now turned into doubles to avoid integer overflow.
11+
112
# kernelshap 0.3.8
213

314
## API improvements

R/from_hstats.R

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# These functions have originally been implemented in {hstats}
2+
3+
#' Fast Index Generation
4+
#'
5+
#' For not too small m, much faster than `rep(seq_len(m), each = each)`.
6+
#'
7+
#' @noRd
8+
#' @keywords internal
9+
#'
10+
#' @param m Integer. See `each`.
11+
#' @param each Integer. How many times should each value in `1:m` be repeated?
12+
#' @returns Like `x`, but converted to matrix.
13+
#' @examples
14+
#' rep_each(10, 2)
15+
#' rep(1:10, each = 2) # Dito
16+
rep_each <- function(m, each) {
17+
out <- .col(dim = c(each, m))
18+
dim(out) <- NULL
19+
out
20+
}
21+
22+
#' Fast OHE
23+
#'
24+
#' Turns vector/factor into its One-Hot-Encoding.
25+
#' Ingeniouly written by Mathias Ambuehl.
26+
#'
27+
#' Working with integers instead of doubles would be slightly faster, but at the price
28+
#' of potential integer overflows in subsequent calculations.
29+
#'
30+
#' @noRd
31+
#' @keywords internal
32+
#'
33+
#' @param x Object representing model predictions.
34+
#' @returns Like `x`, but converted to matrix.
35+
fdummy <- function(x) {
36+
x <- as.factor(x)
37+
lev <- levels(x)
38+
out <- matrix(0, nrow = length(x), ncol = length(lev))
39+
out[cbind(seq_along(x), as.integer(x))] <- 1
40+
colnames(out) <- lev
41+
out
42+
}

R/kernelshap.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
213213
bg_n <- nrow(bg_X)
214214
if (!is.null(bg_w)) {
215215
stopifnot(length(bg_w) == bg_n, all(bg_w >= 0), !all(bg_w == 0))
216+
if (!is.double(bg_w)) {
217+
bg_w <- as.double(bg_w)
218+
}
216219
}
217220
if (is.matrix(X) && !identical(colnames(X), feature_names)) {
218221
stop("If X is a matrix, feature_names must equal colnames(X)")

R/utils.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
141141
n_bg <- nrow(bg) / m # because bg was replicated m times
142142

143143
# Replicate not_Z, so that X, bg, not_Z are all of dimension (m*n_bg x p)
144-
g <- rep(seq_len(m), each = n_bg)
144+
g <- rep_each(m, each = n_bg) # from_hstats.R
145145
not_Z <- not_Z[g, , drop = FALSE]
146146

147147
if (is.matrix(X)) {
@@ -175,7 +175,7 @@ get_vz <- function(X, bg, Z, object, pred_fun, w, ...) {
175175
#' @returns A (1 x ncol(x)) matrix of column means.
176176
weighted_colMeans <- function(x, w = NULL, ...) {
177177
if (NCOL(x) == 1L && is.null(w)) {
178-
return(matrix(mean(x)))
178+
return(as.matrix(mean(x)))
179179
}
180180
if (!is.matrix(x)) {
181181
x <- as.matrix(x)
@@ -226,21 +226,21 @@ reorganize_list <- function(alist) {
226226

227227
#' Aligns Predictions
228228
#'
229-
#' Turns predictions into matrix. Originally implemented in {hstats}.
229+
#' Turns predictions into matrix.
230230
#'
231231
#' @noRd
232232
#' @keywords internal
233233
#'
234234
#' @param x Object representing model predictions.
235235
#' @returns Like `x`, but converted to matrix.
236236
align_pred <- function(x) {
237-
if (!is.matrix(x)) {
238-
x <- as.matrix(x)
237+
if (is.data.frame(x) && ncol(x) == 1L) {
238+
x <- x[[1L]]
239239
}
240-
if (!is.numeric(x)) {
241-
stop("Predictions must be numeric")
240+
if (is.factor(x)) {
241+
return(fdummy(x)) # from_hstats.R
242242
}
243-
x
243+
if (is.matrix(x)) x else as.matrix(x)
244244
}
245245

246246
#' Head of List Elements

README.md

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ If the training data is small, use the full training data. In cases with a natur
2626
**Remarks**
2727

2828
- Multivariate predictions are handled at no additional computational cost.
29+
- Factor-valued predictions are automatically turned into one-hot-encoded columns.
2930
- By changing the defaults, the iterative pure sampling approach in [2] can be enforced.
3031
- Case weights are supported via the argument `bg_w`.
3132

@@ -46,8 +47,8 @@ Let's model diamonds prices!
4647
### Linear regression
4748

4849
```r
49-
library(ggplot2)
5050
library(kernelshap)
51+
library(ggplot2)
5152
library(shapviz)
5253

5354
diamonds <- transform(
@@ -221,6 +222,40 @@ shap_gam
221222
# [2,] -0.5153642 -0.1080045 0.11967804 0.031341595
222223
```
223224

225+
## Multi-output models
226+
227+
{kernelshap} supports multivariate predictions, such as:
228+
- probabilistic classification,
229+
- non-probabilistic classification (factor-valued responses are turned into dummies),
230+
- regression with multivariate response, and
231+
- predictions found by applying multiple regression models.
232+
233+
### Classification
234+
235+
We use {ranger} to fit a probabilistic and a non-probabilistic classification model.
236+
237+
```r
238+
library(kernelshap)
239+
library(ranger)
240+
library(shapviz)
241+
242+
# Probabilistic
243+
fit_prob <- ranger(Species ~ ., data = iris, num.trees = 20, probability = TRUE, seed = 1)
244+
ks_prob <- kernelshap(fit_prob, X = iris, bg_X = iris) |>
245+
shapviz()
246+
sv_importance(ks_prob)
247+
248+
# Non-probabilistic: Predictions are factors
249+
fit_class <- ranger(Species ~ ., data = iris, num.trees = 20, seed = 1)
250+
ks_class <- kernelshap(fit_class, X = iris, bg_X = iris) |>
251+
shapviz()
252+
sv_importance(ks_class)
253+
```
254+
255+
![](man/figures/README-prob-class.svg)
256+
257+
![](man/figures/README-fact-class.svg)
258+
224259
## Meta-learning packages
225260

226261
Here, we provide some working examples for "tidymodels", "caret", and "mlr3".
@@ -283,7 +318,7 @@ fit_lm$train(task_iris)
283318
s <- kernelshap(fit_lm, iris[-1], bg_X = iris)
284319
s
285320

286-
# Probabilistic classification -> lrn(..., predict_type = "prob")
321+
# *Probabilistic* classification -> lrn(..., predict_type = "prob")
287322
task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
288323
fit_rf <- lrn("classif.ranger", predict_type = "prob", num.trees = 50)
289324
fit_rf$train(task_iris)

0 commit comments

Comments
 (0)