Skip to content

Commit 55c8ea7

Browse files
authored
Merge pull request #143 from ModelOriented/improve-tests
More compact tests
2 parents 16f9c3e + 3494213 commit 55c8ea7

12 files changed

+383
-711
lines changed
Lines changed: 35 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,34 @@
1-
test_that("simple additive formula gives same as permshap() if full training data is used as bg data", {
2-
form <- Sepal.Length ~ .
3-
fit_lm <- lm(form, data = iris)
4-
fit_glm <- glm(form, data = iris, family = quasipoisson)
5-
6-
s_add_lm <- additive_shap(fit_lm, head(iris), verbose = FALSE)
7-
s_add_glm <- additive_shap(fit_glm, head(iris), verbose = FALSE)
8-
9-
X <- head(iris[-1L])
10-
s_perm_lm <- permshap(fit_lm, X = X, bg_X = iris, verbose = FALSE)
11-
s_perm_glm <- permshap(
12-
fit_glm, X = X, bg_X = iris, verbose = FALSE
1+
test_that("Additive formulas give same as agnostic SHAP with full training data as bg data", {
2+
formulas <- list(
3+
Sepal.Length ~ .,
4+
Sepal.Length ~ log(Sepal.Width) + poly(Sepal.Width, 2) + Petal.Length,
5+
form <- Sepal.Length ~ log(Sepal.Width) + Species + poly(Petal.Length, 2)
136
)
14-
expect_equal(s_add_lm$S, s_perm_lm$S)
15-
expect_equal(s_add_glm$S, s_perm_glm$S)
16-
expect_equal(s_add_lm$predictions, unname(predict(fit_lm, newdata = X)))
17-
expect_equal(s_add_glm$predictions, unname(predict(fit_glm, newdata = X)))
18-
})
19-
20-
test_that("formula where feature appears in two terms gives same as permshap() if full training data is used as bg data", {
21-
form <- Sepal.Length ~ log(Sepal.Width) + poly(Sepal.Width, 2) + Petal.Length
22-
fit_lm <- lm(form, data = iris)
23-
fit_glm <- glm(form, data = iris, family = quasipoisson)
24-
25-
s_add_lm <- additive_shap(fit_lm, head(iris), verbose = FALSE)
26-
s_add_glm <- additive_shap(fit_glm, head(iris), verbose = FALSE)
27-
28-
X <- head(iris[2:3])
29-
s_perm_lm <- permshap(fit_lm, X = X, bg_X = iris, verbose = FALSE)
30-
s_perm_glm <- permshap(
31-
fit_glm, X = X, bg_X = iris, verbose = FALSE
7+
xvars <- list(
8+
setdiff(colnames(iris), "Sepal.Length"),
9+
c("Sepal.Width", "Petal.Length"),
10+
xvars <- c("Sepal.Width", "Petal.Length", "Species")
3211
)
33-
expect_equal(s_add_lm$S, s_perm_lm$S)
34-
expect_equal(s_add_glm$S, s_perm_glm$S)
35-
expect_equal(s_add_lm$predictions, unname(predict(fit_lm, newdata = X)))
36-
expect_equal(s_add_glm$predictions, unname(predict(fit_glm, newdata = X)))
37-
})
38-
39-
test_that("formula with complicated terms gives same as permshap() if full training data is used as bg data", {
40-
form <- Sepal.Length ~
41-
log(Sepal.Width) + Species + poly(Petal.Length, 2)
4212

43-
fit_lm <- lm(form, data = iris)
44-
fit_glm <- glm(form, data = iris, family = quasipoisson)
45-
46-
s_add_lm <- additive_shap(fit_lm, head(iris), verbose = FALSE)
47-
s_add_glm <- additive_shap(fit_glm, head(iris), verbose = FALSE)
48-
49-
X <- head(iris[c(2, 3, 5)])
50-
s_perm_lm <- permshap(fit_lm, X = X, bg_X = iris, verbose = FALSE)
51-
s_perm_glm <- permshap(
52-
fit_glm, X = X, bg_X = iris, verbose = FALSE
53-
)
54-
expect_equal(s_add_lm$S, s_perm_lm$S)
55-
expect_equal(s_add_glm$S, s_perm_glm$S)
56-
expect_equal(s_add_lm$predictions, unname(predict(fit_lm, newdata = X)))
57-
expect_equal(s_add_glm$predictions, unname(predict(fit_glm, newdata = X)))
13+
for (j in seq_along(formulas)) {
14+
fit <- list(
15+
lm = lm(formulas[[j]], data = iris),
16+
glm = glm(formulas[[j]], data = iris, family = quasipoisson)
17+
)
18+
19+
shap1 <- lapply(fit, additive_shap, head(iris), verbose = FALSE)
20+
shap2 <- lapply(
21+
fit, permshap, head(iris), bg_X = iris, verbose = FALSE, feature_names = xvars[[j]]
22+
)
23+
shap3 <- lapply(
24+
fit, kernelshap, head(iris), bg_X = iris, verbose = FALSE, feature_names = xvars[[j]]
25+
)
26+
27+
for (i in seq_along(fit)) {
28+
expect_equal(shap1[[i]]$S, shap2[[i]]$S)
29+
expect_equal(shap1[[i]]$S, shap3[[i]]$S)
30+
}
31+
}
5832
})
5933

6034
test_that("formulas with more than one covariate per term fail", {
@@ -65,10 +39,12 @@ test_that("formulas with more than one covariate per term fail", {
6539
)
6640

6741
for (formula in formulas_bad) {
68-
fit <- lm(formula, data = iris)
69-
expect_error(s <- additive_shap(fit, head(iris), verbose = FALSE))
70-
71-
fit <- glm(formula, data = iris, family = quasipoisson)
72-
expect_error(s <- additive_shap(fit, head(iris), verbose = FALSE))
42+
fit <- list(
43+
lm = lm(formula, data = iris),
44+
glm = glm(formula, data = iris, family = quasipoisson)
45+
)
46+
for (f in fit)
47+
expect_error(additive_shap(f, head(iris), verbose = FALSE))
7348
}
7449
})
50+

tests/testthat/test-basic.R

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# Model with non-linearities and interactions
2+
fit <- lm(
3+
Sepal.Length ~ poly(Petal.Width, degree = 2L) * Species + Petal.Length, data = iris
4+
)
5+
x <- c("Petal.Width", "Species", "Petal.Length")
6+
preds <- unname(predict(fit, iris))
7+
J <- c(1L, 51L, 101L)
8+
9+
shap <- list(
10+
kernelshap(fit, iris[x], bg_X = iris, verbose = FALSE),
11+
permshap(fit, iris[x], bg_X = iris, verbose = FALSE)
12+
)
13+
14+
test_that("baseline equals average prediction on background data", {
15+
for (s in shap)
16+
expect_equal(s$baseline, mean(iris$Sepal.Length))
17+
})
18+
19+
test_that("SHAP + baseline = prediction for exact mode", {
20+
for (s in shap)
21+
expect_equal(rowSums(s$S) + s$baseline, preds)
22+
})
23+
24+
test_that("auto-selection of background data works", {
25+
# Here, the background data equals the full X
26+
shap2 <- list(
27+
kernelshap(fit, iris[x], verbose = FALSE),
28+
permshap(fit, iris[x], verbose = FALSE)
29+
)
30+
31+
for (i in 1:2) {
32+
expect_equal(shap$S, shap2$S)
33+
}
34+
})
35+
36+
test_that("missing bg_X gives error if X is very small", {
37+
for (algo in c(kernelshap, permshap))
38+
expect_error(algo(fit, iris[1:10, x], verbose = FALSE))
39+
40+
})
41+
42+
test_that("missing bg_X gives warning if X is quite small", {
43+
for (algo in c(kernelshap, permshap))
44+
expect_warning(algo(fit, iris[1:30, x], verbose = FALSE))
45+
})
46+
47+
test_that("selection of bg_X can be controlled via bg_n", {
48+
for (algo in c(kernelshap, permshap)) {
49+
s <- algo(fit, iris[x], verbose = FALSE, bg_n = 20L)
50+
expect_equal(nrow(s$bg_X), 20L)
51+
}
52+
})
53+
54+
test_that("using foreach (non-parallel) gives the same as normal mode", {
55+
for (algo in c(kernelshap, permshap)) {
56+
s <- algo(fit, iris[J, x], bg_X = iris, verbose = FALSE)
57+
s2 <- suppressWarnings(
58+
algo(fit, iris[J, x], bg_X = iris, verbose = FALSE, parallel = TRUE)
59+
)
60+
expect_equal(s, s2)
61+
}
62+
})
63+
64+
test_that("verbose is chatty", {
65+
for (algo in c(kernelshap, permshap)) {
66+
capture_output(expect_message(algo(fit, iris[J, x], bg_X = iris, verbose = TRUE)))
67+
}
68+
})
69+
70+
test_that("large background data cause warning", {
71+
# Takes a bit of time, thus only for one algo
72+
large_bg <- iris[rep(1:150, 230), ]
73+
expect_warning(
74+
kernelshap(fit, iris[1L, x], bg_X = large_bg, verbose = FALSE)
75+
)
76+
})
77+
78+
test_that("Decomposing a single row works", {
79+
for (algo in c(kernelshap, permshap)) {
80+
s <- algo(fit, iris[1L, x], bg_X = iris, verbose = FALSE)
81+
expect_equal(s$baseline, mean(iris$Sepal.Length))
82+
expect_equal(rowSums(s$S) + s$baseline, preds[1])
83+
}
84+
})
85+
86+
test_that("Background data can contain additional columns", {
87+
for (algo in c(kernelshap, permshap)) {
88+
s <- algo(fit, iris[1L, x], bg_X = cbind(d = 1, iris), verbose = FALSE)
89+
expect_true(is.kernelshap(s))
90+
}
91+
})
92+
93+
test_that("Background data can contain only one single row", {
94+
for (algo in c(kernelshap, permshap))
95+
expect_no_error(algo(fit, iris[1L, x], bg_X = iris[150L, ], verbose = FALSE))
96+
})
97+
98+
test_that("feature_names can drop columns from SHAP calculations", {
99+
for (algo in c(kernelshap, permshap)) {
100+
s <- algo(fit, iris[J, ], bg_X = iris, feature_names = x, verbose = FALSE)
101+
expect_equal(colnames(s$S), x)
102+
}
103+
})
104+
105+
test_that("feature_names can rearrange column names in result", {
106+
for (algo in c(kernelshap, permshap)) {
107+
s <- algo(fit, iris[J, ], bg_X = iris, feature_names = rev(x), verbose = FALSE)
108+
expect_equal(colnames(s$S), rev(x))
109+
}
110+
})
111+
112+
test_that("feature_names must be in colnames(X) and colnames(bg_X)", {
113+
for (algo in c(kernelshap, permshap)) {
114+
expect_error(algo(fit, iris, bg_X = cbind(iris, a = 1), feature_names = "a"))
115+
expect_error(algo(fit, cbind(iris, a = 1), bg_X = iris, feature_names = "a"))
116+
}
117+
})
118+
119+
test_that("Matrix input is fine", {
120+
X <- data.matrix(iris)
121+
pred_fun <- function(m, X) {
122+
data <- as.data.frame(X) |>
123+
transform(Species = factor(Species, labels = levels(iris$Species)))
124+
predict(m, data)
125+
}
126+
127+
for (algo in c(kernelshap, permshap)) {
128+
s <- algo(fit, X[J, x], pred_fun = pred_fun, bg_X = X, verbose = FALSE)
129+
130+
expect_equal(s$baseline, mean(iris$Sepal.Length)) # baseline is mean of bg
131+
expect_equal(rowSums(s$S) + s$baseline, preds[J]) # sum shap = centered preds
132+
expect_no_error( # additional cols in bg are ok
133+
algo(fit, X[J, x], pred_fun = pred_fun, bg_X = cbind(d = 1, X), verbose = FALSE)
134+
)
135+
expect_error( # feature_names are less flexible
136+
algo(fit, X[J, ], pred_fun = pred_fun, bg_X = X,
137+
verbose = FALSE, feature_names = "Sepal.Width")
138+
)
139+
}
140+
})
141+
142+
test_that("Special case p = 1 works only for kernelshap()", {
143+
capture_output(
144+
expect_message(
145+
s <- kernelshap(fit, X = iris[J, ], bg_X = iris, feature_names = "Petal.Width")
146+
)
147+
)
148+
expect_equal(s$baseline, mean(iris$Sepal.Length))
149+
expect_equal(unname(rowSums(s$S)) + s$baseline, preds[J])
150+
expect_equal(s$SE[1L], 0)
151+
152+
expect_error( # Not implemented
153+
permshap(
154+
fit, iris[J, ], bg_X = iris, verbose = FALSE, feature_names = "Petal.Width"
155+
)
156+
)
157+
})
158+
159+
test_that("exact hybrid kernelshap() is similar to exact (non-hybrid)", {
160+
s1 <- kernelshap(
161+
fit, iris[J, x], bg_X = iris, exact = FALSE, hybrid_degree = 1L, verbose = FALSE
162+
)
163+
expect_equal(s1$S, shap[[1L]]$S[J, ])
164+
})
165+
166+
test_that("baseline equals average prediction on background data in sampling mode", {
167+
s2 <- s_sampling <- kernelshap(
168+
fit, iris[J, x], bg_X = iris, hybrid_degree = 0L, verbose = FALSE, exact = FALSE
169+
)
170+
expect_equal(s2$baseline, mean(iris$Sepal.Length))
171+
})
172+
173+
test_that("SHAP + baseline = prediction for sampling mode", {
174+
s2 <- s_sampling <- kernelshap(
175+
fit, iris[J, x], bg_X = iris, hybrid_degree = 0L, verbose = FALSE, exact = FALSE
176+
)
177+
expect_equal(rowSums(s2$S) + s2$baseline, preds[J])
178+
})
179+
180+
test_that("kernelshap works for large p (hybrid case)", {
181+
set.seed(9L)
182+
X <- data.frame(matrix(rnorm(20000L), ncol = 100L))
183+
y <- X[, 1L] * X[, 2L] * X[, 3L]
184+
fit <- lm(y ~ X1:X2:X3 + ., data = cbind(y = y, X))
185+
s <- kernelshap(fit, X[1L, ], bg_X = X, verbose = FALSE)
186+
187+
expect_equal(s$baseline, mean(y))
188+
expect_equal(rowSums(s$S) + s$baseline, unname(predict(fit, X[1L, ])))
189+
})
190+

tests/testthat/test-kernelshap-multioutput.R

Lines changed: 0 additions & 88 deletions
This file was deleted.

tests/testthat/test-kernelshap-utils.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
test_that("Sum of kernel weights is 1", {
1+
test_that("sum of kernel weights is 1", {
22
for (p in 2:10) {
33
expect_equal(sum(kernel_weights(p)), 1.0)
44
}
@@ -121,3 +121,4 @@ test_that("input_partly_exact(p, deg) fails for bad p or deg", {
121121
expect_error(input_partly_exact(2L, deg = 0L, feature_names = LETTERS[1:p]))
122122
expect_error(input_partly_exact(5L, deg = 3L, feature_names = LETTERS[1:p]))
123123
})
124+

0 commit comments

Comments
 (0)