@@ -26,6 +26,7 @@ If the training data is small, use the full training data. In cases with a natur
2626** Remarks**
2727
2828- Multivariate predictions are handled at no additional computational cost.
29+ - Factor-valued predictions are automatically turned into one-hot-encoded columns.
2930- By changing the defaults, the iterative pure sampling approach in [ 2] can be enforced.
3031- Case weights are supported via the argument ` bg_w ` .
3132
@@ -46,8 +47,8 @@ Let's model diamonds prices!
4647### Linear regression
4748
4849``` r
49- library(ggplot2 )
5050library(kernelshap )
51+ library(ggplot2 )
5152library(shapviz )
5253
5354diamonds <- transform(
@@ -221,6 +222,40 @@ shap_gam
221222# [2,] -0.5153642 -0.1080045 0.11967804 0.031341595
222223```
223224
225+ ## Multi-output models
226+
227+ {kernelshap} supports multivariate predictions, such as:
228+ - probabilistic classification,
229+ - non-probabilistic classification (factor-valued responses are turned into dummies),
230+ - regression with multivariate response, and
231+ - predictions found by applying multiple regression models.
232+
233+ ### Classification
234+
235+ We use {ranger} to fit a probabilistic and a non-probabilistic classification model.
236+
237+ ``` r
238+ library(kernelshap )
239+ library(ranger )
240+ library(shapviz )
241+
242+ # Probabilistic
243+ fit_prob <- ranger(Species ~ . , data = iris , num.trees = 20 , probability = TRUE , seed = 1 )
244+ ks_prob <- kernelshap(fit_prob , X = iris , bg_X = iris ) | >
245+ shapviz()
246+ sv_importance(ks_prob )
247+
248+ # Non-probabilistic: Predictions are factors
249+ fit_class <- ranger(Species ~ . , data = iris , num.trees = 20 , seed = 1 )
250+ ks_class <- kernelshap(fit_class , X = iris , bg_X = iris ) | >
251+ shapviz()
252+ sv_importance(ks_class )
253+ ```
254+
255+ ![ ] ( man/figures/README-prob-class.svg )
256+
257+ ![ ] ( man/figures/README-fact-class.svg )
258+
224259## Meta-learning packages
225260
226261Here, we provide some working examples for "tidymodels", "caret", and "mlr3".
@@ -283,7 +318,7 @@ fit_lm$train(task_iris)
283318s <- kernelshap(fit_lm , iris [- 1 ], bg_X = iris )
284319s
285320
286- # Probabilistic classification -> lrn(..., predict_type = "prob")
321+ # * Probabilistic* classification -> lrn(..., predict_type = "prob")
287322task_iris <- TaskClassif $ new(id = " class" , backend = iris , target = " Species" )
288323fit_rf <- lrn(" classif.ranger" , predict_type = " prob" , num.trees = 50 )
289324fit_rf $ train(task_iris )
0 commit comments