Commit 24a596c

Merge pull request #125 from ModelOriented/readme-more-improvements
better GAM, better random forest in readme
2 parents 060d4eb + 6eae1c3 commit 24a596c

7 files changed: +12625 -4170 lines changed

README.md

Lines changed: 54 additions & 44 deletions
@@ -51,7 +51,7 @@ devtools::install_github("ModelOriented/kernelshap")
 
 ## Basic Usage
 
-Let's model diamond prices with a (not too complex) random forest. As an alternative, you could use the {treeshap} package.
+Let's model diamond prices with a (not too complex) random forest. As an alternative, you could use the {treeshap} package in this situation.
 
 ```r
 library(kernelshap)
@@ -70,10 +70,10 @@ xvars <- c("log_carat", "clarity", "color", "cut")
 fit <- ranger(
   log_price ~ log_carat + clarity + color + cut,
   data = diamonds,
-  num.trees = 200,
-  max.depth = 8,
+  num.trees = 100,
   seed = 20
 )
+fit # OOB R-squared 0.989
 
 # 1) Sample rows to be explained
 set.seed(10)
@@ -82,26 +82,26 @@ X <- diamonds[sample(nrow(diamonds), 1000), xvars]
 # 2) Select background data
 bg_X <- diamonds[sample(nrow(diamonds), 200), ]
 
-# 3) Crunch SHAP values for all 1000 rows of X (~27 seconds)
+# 3) Crunch SHAP values for all 1000 rows of X (54 seconds)
 # Note: Since the number of features is small, we use permshap()
 system.time(
   ps <- permshap(fit, X, bg_X = bg_X)
 )
 ps
 
 # SHAP values of first observations:
-# log_carat clarity color cut
-# [1,] 1.1203506 0.06502841 -0.14040781 0.003009253
-# [2,] -0.4798394 -0.09477441 0.07874888 0.019542365
+log_carat clarity color cut
+[1,] 1.1913247 0.09005467 -0.13430720 0.000682593
+[2,] -0.4931989 -0.11724773 0.09868921 0.028563613
 
 # Kernel SHAP gives almost the same:
 system.time( # 28 s
   ks <- kernelshap(fit, X, bg_X = bg_X)
 )
 ks
 # log_carat clarity color cut
-# [1,] 1.1204982 0.06499042 -0.14074578 0.00323761
-# [2,] -0.4795507 -0.09443787 0.07867385 0.01899212
+# [1,] 1.1911791 0.0900462 -0.13531648 0.001845958
+# [2,] -0.4927482 -0.1168517 0.09815062 0.028255442
 
 # 4) Analyze with our sister package {shapviz}
 ps <- shapviz(ps)
@@ -117,6 +117,51 @@ sv_dependence(ps, xvars)
 
 {kernelshap} can deal with almost any situation. We will show some of the flexibility here. The first two examples require you to run at least up to Step 2 of the "Basic Usage" code.
 
+### Parallel computing
+
+Parallel computing is supported via {foreach}. Note that this does not work with all models, and that there is no progress bar.
+
+On Windows, sometimes not all packages or global objects are passed to the parallel sessions. Often, this can be fixed via `parallel_args`, see the generalized additive model below.
+
+```r
+library(doFuture)
+library(mgcv)
+
+registerDoFuture()
+plan(multisession, workers = 4) # Windows
+# plan(multicore, workers = 4) # Linux, macOS, Solaris
+
+fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds)
+
+system.time( # 9 seconds in parallel
+  ps <- permshap(
+    fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "mgcv")
+  )
+)
+ps
+
+# SHAP values of first observations:
+# log_carat clarity color cut
+# [1,] 1.26801 0.1023518 -0.09223291 0.004512402
+# [2,] -0.51546 -0.1174766 0.11122775 0.030243973
+
+# Because there are no interactions of order above 2, Kernel SHAP gives the same:
+system.time( # 27 s non-parallel
+  ks <- kernelshap(fit, X, bg_X = bg_X)
+)
+all.equal(ps$S, ks$S)
+# [1] TRUE
+
+# Now the usual plots:
+sv <- shapviz(ps)
+sv_importance(sv, kind = "bee")
+sv_dependence(sv, xvars)
+```
+
+![](man/figures/README-gam-imp.svg)
+
+![](man/figures/README-gam-dep.svg)
+
 ### Taylored predict()
 
 In this {keras} example, we show how to use a tailored `predict()` function that complies with
@@ -170,41 +215,6 @@ sv_dependence(ps, xvars)
 
 ![](man/figures/README-nn-dep.svg)
 
-### Parallel computing
-
-Parallel computing is supported via {foreach}. Note that this does not work with all models, and that there is no progress bar.
-
-On Windows, sometimes not all packages or global objects are passed to the parallel sessions. In this case, the necessary instructions to {foreach} can be specified through a named list via `parallel_args`, see the following example.
-
-```r
-library(doFuture)
-library(mgcv)
-
-registerDoFuture()
-plan(multisession, workers = 4) # Windows
-# plan(multicore, workers = 4) # Linux, macOS, Solaris
-
-fit <- gam(log_price ~ s(log_carat) + clarity + color + cut, data = diamonds)
-
-system.time( # 9 seconds
-  ps <- permshap(
-    fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "mgcv")
-  )
-)
-ps
-
-# SHAP values of first observations:
-# log_carat clarity color cut
-# [1,] 1.2714988 0.1115546 -0.08454955 0.003220451
-# [2,] -0.5153642 -0.1080045 0.11967804 0.031341595
-
-# Because there are no high-order interactions, Kernel SHAP gives the same:
-kernelshap(fit, X[1:2, ], bg_X = bg_X)
-# log_carat clarity color cut
-# [1,] 1.2714988 0.1115546 -0.08454955 0.003220451
-# [2,] -0.5153642 -0.1080045 0.11967804 0.031341595
-```
-
 ### Multi-output models
 
 {kernelshap} supports multivariate predictions like:

backlog/plot_settings

Lines changed: 0 additions & 1 deletion
This file was deleted.

backlog/plot_settings.R

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+ggsave("man/figures/README-rf-imp.svg", scale = 2)
+ggsave("man/figures/README-rf-dep.svg", width = 8.5, height = 6)
+
+ggsave("man/figures/README-gam-imp.svg", scale = 2)
+ggsave("man/figures/README-gam-dep.svg", width = 8.5, height = 6)
