#'
#' @param object Fitted model object.
#' @param X A (n x p) matrix, data.frame, tibble or data.table of rows to be explained.
- #'   Important: The columns should only represent model features, not the response.
+ #'   The columns should only represent model features, not the response.
#' @param bg_X Background data used to integrate out "switched off" features,
#'   often a subset of the training data (typically 50 to 500 rows).
#'   It should contain the same columns as \code{X}.
#'   In cases with a natural "off" value (like MNIST digits),
#'   this can also be a single row with all values set to the off value.
#' @param pred_fun Prediction function of the form \code{function(object, X, ...)},
#'   providing K >= 1 numeric predictions per row. Its first argument represents the
- #'   model \code{object}, its second argument a data structure like \code{X} and \code{bg_X}.
- #'   (The names of the first two arguments do not matter.) Additional (named)
- #'   arguments are passed via \code{...}. The default, \code{stats::predict}, will
- #'   work in most cases. Some exceptions (classes "ranger" and mlr3 "Learner")
- #'   are handled separately. In other cases, the function must be specified manually.
+ #'   model \code{object}, its second argument a data structure like \code{X}.
+ #'   Additional (named) arguments are passed via \code{...}.
+ #'   The default, \code{stats::predict}, will work in most cases.
#' @param feature_names Optional vector of column names in \code{X} used to calculate
- #'   SHAP values. By default, this equals \code{colnames(X)}. Not supported for matrix
- #'   \code{X}.
+ #'   SHAP values. By default, this equals \code{colnames(X)}. Not supported if \code{X}
+ #'   is a matrix.
#' @param bg_w Optional vector of case weights for each row of \code{bg_X}.
#' @param exact If \code{TRUE}, the algorithm will produce exact Kernel SHAP values
#'   with respect to the background data. In this case, the arguments \code{hybrid_degree},
#'}
#' @export
#' @examples
- #' # Linear regression
+ #' # MODEL ONE: Linear regression
#' fit <- stats::lm(Sepal.Length ~ ., data = iris)
- #' s <- kernelshap(fit, iris[1:2, -1], bg_X = iris)
+ #'
+ #' # Select rows to explain (only feature columns)
+ #' X_explain <- iris[1:2, -1]
+ #'
+ #' # Select a small background dataset (could use all rows here because iris is small)
+ #' set.seed(1)
+ #' bg_X <- iris[sample(nrow(iris), 100), ]
+ #'
+ #' # Calculate SHAP values
+ #' s <- kernelshap(fit, X_explain, bg_X = bg_X)
#' s
#'
- #' # Multivariate model
+ #' # MODEL TWO: Multi-response linear regression
#' fit <- stats::lm(
#'   as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris
#' )
- #' s <- kernelshap(fit, iris[1:4, 3:5], bg_X = iris)
+ #' s <- kernelshap(fit, iris[1:4, 3:5], bg_X = bg_X)
#' summary(s)
- #'
- #' # Matrix input works as well, and pred_fun can be overwritten
- #' fit <- stats::lm(Sepal.Length ~ ., data = iris[1:4])
- #' pred_fun <- function(fit, X) stats::predict(fit, as.data.frame(X))
- #' X <- data.matrix(iris[2:4])
- #' s <- kernelshap(fit, X[1:3, ], bg_X = X, pred_fun = pred_fun)
- #' s
- #'
- #' # Logistic regression
- #' fit <- stats::glm(
- #'   I(Species == "virginica") ~ Sepal.Length + Sepal.Width,
- #'   data = iris,
- #'   family = binomial
- #' )
- #'
- #' # On scale of linear predictor
- #' s <- kernelshap(fit, iris[1:2], bg_X = iris)
- #' s
- #'
- #' # On scale of response (probability)
- #' s <- kernelshap(fit, iris[1:2], bg_X = iris, type = "response")
- #' s
#'
#' # Non-feature columns can be dropped via 'feature_names'
- #' fit <- stats::lm(Sepal.Length ~ . - Species, data = iris)
#' s <- kernelshap(
#'   fit,
- #'   iris[1:2, ],
- #'   bg_X = iris,
- #'   feature_names = c("Sepal.Width", "Petal.Length", "Petal.Width")
+ #'   iris[1:4, ],
+ #'   bg_X = bg_X,
+ #'   feature_names = c("Petal.Length", "Petal.Width", "Species")
#' )
#' s
kernelshap <- function(object, ...) {
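The tightened `pred_fun` docs above no longer spell out how to hand-roll a prediction function, and the diff also drops the logistic-regression example. Below is a minimal sketch of the documented `function(object, X, ...)` contract; the names `fit_glm` and `pred_fun_prob` are illustrative, not from the package.

fit_glm <- stats::glm(
  I(Species == "virginica") ~ Sepal.Length + Sepal.Width,
  data = iris, family = binomial
)

# Returns one numeric probability per row (K = 1), as the contract requires
pred_fun_prob <- function(object, X, ...) {
  stats::predict(object, as.data.frame(X), type = "response", ...)
}

# Hypothetical call:
# s <- kernelshap(fit_glm, iris[1:2], bg_X = iris, pred_fun = pred_fun_prob)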
@@ -202,7 +186,8 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
!is.null(colnames(bg_X)),
(p <- length(feature_names)) >= 1L,
all(feature_names %in% colnames(X)),
- all(feature_names %in% colnames(bg_X)),
+ all(feature_names %in% colnames(bg_X)),  # not necessary, but clearer
+ all(colnames(X) %in% colnames(bg_X)),
is.function(pred_fun),
exact %in% c(TRUE, FALSE),
p == 1L || exact || hybrid_degree %in% 0:(p / 2),
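The added `all(colnames(X) %in% colnames(bg_X))` condition guards the later subsetting of `bg_X` to the columns of `X`. A minimal sketch of the failure mode it catches, with toy objects and assuming the conditions sit inside a `stopifnot()`-style check:

X_small  <- iris[1:2, c("Sepal.Width", "Petal.Length")]
bg_small <- iris[, "Sepal.Width", drop = FALSE]  # background lacks Petal.Length
all(colnames(X_small) %in% colnames(bg_small))   # FALSE, so validation aborts early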
@@ -218,10 +203,12 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
stop("If X is a matrix, feature_names must equal colnames(X)")
}

- # Calculate v0 and v1
- bg_preds <- check_pred(pred_fun(object, bg_X, ...), n = bg_n)
- v0 <- weighted_colMeans(bg_preds, bg_w)  # Average pred of bg data: 1 x K
+ # Calculate v1 and v0
v1 <- check_pred(pred_fun(object, X, ...), n = n)  # Predictions on X: n x K
+ bg_preds <- check_pred(
+   pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...), n = bg_n
+ )
+ v0 <- weighted_colMeans(bg_preds, bg_w)  # Average pred of bg data: 1 x K

# For p = 1, exact Shapley values are returned
if (p == 1L) {
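The reworked block now evaluates `pred_fun` on `bg_X[, colnames(X), drop = FALSE]` instead of on `bg_X` directly, so extra background columns (such as the response) and a differing column order never reach the prediction function. A small sketch of just the alignment step, using `iris`:

bg <- iris[1:3, ]                              # still carries the response Sepal.Length
X  <- iris[1:2, -1]                            # feature columns only
bg_aligned <- bg[, colnames(X), drop = FALSE]  # drops extras, matches column order
identical(colnames(bg_aligned), colnames(X))   # TRUE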
@@ -231,6 +218,7 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
}

# Drop unnecessary columns in bg_X. If X is a matrix, column order is relevant too
+ # From here on, predictions are no longer applied directly to bg_X
if (!identical(colnames(bg_X), feature_names)) {
  bg_X <- bg_X[, feature_names, drop = FALSE]
}
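Both subsetting steps use `drop = FALSE`, which matters in R whenever a single column could remain. A quick illustration with a toy data.frame, not from the package:

df <- data.frame(a = 1:3, b = 4:6)
str(df[, "a"])                # atomic vector: data.frame structure dropped
str(df[, "a", drop = FALSE])  # remains a one-column data.frame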