Skip to content

Commit f4d4e22

Browse files
authored
Merge pull request #127 from ModelOriented/simplify-mlr3
Greatly simplify mlr3 workflow
2 parents 70e0c68 + 9da2237 commit f4d4e22

File tree

11 files changed

+68
-127
lines changed

11 files changed

+68
-127
lines changed

.github/workflows/test-coverage.yaml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@ jobs:
3333
clean = FALSE,
3434
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package"),
3535
function_exclusions = c(
36-
"kernelshap\\.Learner",
3736
"kernelshap\\.ranger",
38-
"permshap\\.Learner",
39-
"permshap\\.ranger",
40-
"mlr3_pred_fun"
37+
"permshap\\.ranger"
4138
)
4239
)
4340
shell: Rscript {0}

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: kernelshap
22
Title: Kernel SHAP
3-
Version: 0.4.1
3+
Version: 0.4.2
44
Authors@R: c(
55
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre")),
66
person("David", "Watson", , "david.s.watson11@gmail.com", role = "aut"),

NAMESPACE

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Generated by roxygen2: do not edit by hand
22

3-
S3method(kernelshap,Learner)
43
S3method(kernelshap,default)
54
S3method(kernelshap,ranger)
6-
S3method(permshap,Learner)
75
S3method(permshap,default)
86
S3method(permshap,ranger)
97
S3method(print,kernelshap)

NEWS.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# kernelshap 0.4.2
2+
3+
## API
4+
5+
- {mlr3}: Non-probabilistic classification now works.
6+
- {mlr3}: For *probabilistic* classification, you now have to pass `predict_type = "prob"`.
7+
8+
## Documentation
9+
10+
- The README has received an {mlr3} and {caret} example.
11+
112
# kernelshap 0.4.1
213

314
## Performance improvements

R/kernelshap.R

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -355,36 +355,3 @@ kernelshap.ranger <- function(object, X, bg_X,
355355
)
356356
}
357357

358-
#' @describeIn kernelshap Kernel SHAP method for "mlr3" models, see Readme for an example.
359-
#' @export
360-
kernelshap.Learner <- function(object, X, bg_X,
361-
pred_fun = NULL,
362-
feature_names = colnames(X),
363-
bg_w = NULL, exact = length(feature_names) <= 8L,
364-
hybrid_degree = 1L + length(feature_names) %in% 4:16,
365-
paired_sampling = TRUE,
366-
m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)),
367-
tol = 0.005, max_iter = 100L, parallel = FALSE,
368-
parallel_args = NULL, verbose = TRUE, ...) {
369-
if (is.null(pred_fun)) {
370-
pred_fun <- mlr3_pred_fun(object, X = X)
371-
}
372-
kernelshap.default(
373-
object = object,
374-
X = X,
375-
bg_X = bg_X,
376-
pred_fun = pred_fun,
377-
feature_names = feature_names,
378-
bg_w = bg_w,
379-
exact = exact,
380-
hybrid_degree = hybrid_degree,
381-
paired_sampling = paired_sampling,
382-
m = m,
383-
tol = tol,
384-
max_iter = max_iter,
385-
parallel = parallel,
386-
parallel_args = parallel_args,
387-
verbose = verbose,
388-
...
389-
)
390-
}

R/permshap.R

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -168,26 +168,3 @@ permshap.ranger <- function(object, X, bg_X,
168168
)
169169
}
170170

171-
#' @describeIn permshap Permutation SHAP method for "mlr3" models, see Readme for an example.
172-
#' @export
173-
permshap.Learner <- function(object, X, bg_X,
174-
pred_fun = NULL,
175-
feature_names = colnames(X),
176-
bg_w = NULL, parallel = FALSE, parallel_args = NULL,
177-
verbose = TRUE, ...) {
178-
if (is.null(pred_fun)) {
179-
pred_fun <- mlr3_pred_fun(object, X = X)
180-
}
181-
permshap.default(
182-
object = object,
183-
X = X,
184-
bg_X = bg_X,
185-
pred_fun = pred_fun,
186-
feature_names = feature_names,
187-
bg_w = bg_w,
188-
parallel = parallel,
189-
parallel_args = parallel_args,
190-
verbose = verbose,
191-
...
192-
)
193-
}

R/utils.R

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -380,26 +380,3 @@ prep_w <- function(w, bg_n) {
380380
if (!is.double(w)) as.double(w) else w
381381
}
382382

383-
#' mlr3 Helper
384-
#'
385-
#' Returns the prediction function of a mlr3 Learner.
386-
#'
387-
#' @noRd
388-
#' @keywords internal
389-
#'
390-
#' @param object Learner object.
391-
#' @param X Dataframe like object.
392-
#'
393-
#' @returns A function.
394-
mlr3_pred_fun <- function(object, X) {
395-
if ("classif" %in% object$task_type) {
396-
# Check if probabilities are available
397-
test_pred <- object$predict_newdata(utils::head(X))
398-
if ("prob" %in% test_pred$predict_types) {
399-
return(function(m, X) m$predict_newdata(X)$prob)
400-
} else {
401-
stop("Set lrn(..., predict_type = 'prob') to allow for probabilistic classification.")
402-
}
403-
}
404-
function(m, X) m$predict_newdata(X)$response
405-
}

README.md

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,11 @@ ps_class <- permshap(fit_class, X = iris[, -5], bg_X = iris)
249249

250250
![](man/figures/README-prob-dep.svg)
251251

252-
### Tidymodels
252+
### Meta-learners
253253

254-
Meta-learning packages like {tidymodels}, {caret} or {mlr3} are straight-forward to use. The following example additionally shows that the `...` argument of `permshap()` and `kernelshap()` is passed to `predict()`.
254+
Meta-learning packages like {tidymodels}, {caret} or {mlr3} are straightforward to use. The following examples additionally shows that the `...` arguments of `permshap()` and `kernelshap()` are passed to `predict()`.
255+
256+
#### Tidymodels
255257

256258
```r
257259
library(kernelshap)
@@ -285,6 +287,56 @@ $.pred_setosa
285287
[2,] 0.02628333 0.001315556 0.3683833 0.2706111
286288
```
287289

290+
#### caret
291+
292+
```r
293+
library(kernelshap)
294+
library(caret)
295+
296+
fit <- train(
297+
Sepal.Length ~ .,
298+
data = iris,
299+
method = "lm",
300+
tuneGrid = data.frame(intercept = TRUE),
301+
trControl = trainControl(method = "none")
302+
)
303+
304+
ps <- permshap(fit, iris[-1], bg_X = iris)
305+
```
306+
307+
#### mlr3
308+
309+
```r
310+
library(kernelshap)
311+
library(mlr3)
312+
library(mlr3learners)
313+
314+
set.seed(1)
315+
316+
task_classif <- TaskClassif$new(id = "1", backend = iris, target = "Species")
317+
learner_classif <- lrn("classif.rpart", predict_type = "prob")
318+
learner_classif$train(task_classif)
319+
320+
predict(learner_classif, head(iris)) # setosa setosa # Classes
321+
predict(learner_classif, head(iris), predict_type = "prob") # Probs per class
322+
323+
x <- learner_classif$selected_features()
324+
325+
# For *probabilistic* classification, pass predict_type = "prob" to mlr3's predict()
326+
ps <- permshap(
327+
learner_classif, X = iris, bg_X = iris, feature_names = x, predict_type = "prob"
328+
)
329+
ps
330+
# $setosa
331+
# Petal.Length Petal.Width
332+
# [1,] 0.6666667 0
333+
# [2,] 0.6666667 0
334+
335+
# Non-probabilistic classification uses auto-OHE internally
336+
ps <- permshap(learner_classif, X = iris, bg_X = iris, feature_names = x)
337+
ps
338+
```
339+
288340
## References
289341

290342
[1] Erik Štrumbelj and Igor Kononenko. Explaining prediction models and individual predictions with feature contributions. Knowledge and Information Systems 41, 2014.

man/kernelshap.Rd

Lines changed: 0 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/permshap.Rd

Lines changed: 0 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)