@@ -193,18 +193,7 @@ test_that("exact hybrid kernelshap() is similar to exact (non-hybrid)", {
193193 expect_equal(s1 $ S , shap [[1L ]]$ S )
194194})
195195
196- # test_that("kernelshap works for large p (hybrid case)", {
197- # set.seed(9L)
198- # X <- data.frame(matrix(rnorm(20000L), ncol = 100L))
199- # y <- X[, 1L] * X[, 2L] * X[, 3L]
200- # fit <- lm(y ~ X1:X2:X3 + ., data = cbind(y = y, X))
201- # s <- kernelshap(fit, X[1L, ], bg_X = X, verbose = FALSE)
202- #
203- # expect_equal(s$baseline, mean(y))
204- # expect_equal(rowSums(s$S) + s$baseline, unname(predict(fit, X[1L, ])))
205- # })
206-
207- test_that(" kernelshap and permshap work for models with high-order interactions" , {
196+ test_that(" exact kernelshap and permshap agree with Python on model with high-order interactions" , {
208197 # Expected: Python output
209198 # import numpy as np
210199 # import shap # 0.47.2
@@ -263,10 +252,36 @@ test_that("kernelshap and permshap work for models with high-order interactions"
263252
264253 ps <- permshap(pf , head(X , 2 ), bg_X = X , pred_fun = pf , verbose = FALSE )
265254 expect_equal(unname(ps $ S ), expected )
255+ })
266256
267- # Sampling versions of KernelSHAP is quite close
257+ test_that(" Test that sampling versions are close to exact with interaction of order 3" , {
258+ # We use a different example with less multi-collinearity, but still there is strong
259+ # collinearity between the first two features
260+ n <- 100
268261 set.seed(1 )
269- ksh2 <- kernelshap(
262+ X <- data.frame (
263+ x1 = 1 : n ,
264+ x2 = sqrt(1 : n ),
265+ x3 = cos(1 : n ),
266+ x4 = rnorm(n ),
267+ x5 = rexp(n ),
268+ x6 = runif(n )
269+ )
270+
271+ pf <- function (model , newdata ) {
272+ x <- newdata
273+ x [, 1 ] * x [, 2 ] * x [, 3 ] + x [, 4 ]
274+ }
275+ ks <- kernelshap(pf , head(X , 1 ), bg_X = X , pred_fun = pf , verbose = FALSE )
276+ ps <- permshap(pf , head(X , 1 ), bg_X = X , pred_fun = pf , verbose = FALSE )
277+ expect_equal(ks $ S , ps $ S )
278+ expect_equal(
279+ c(ks $ S ), c(- 44.71698 , - 32.89963 , 78.34841 , - 0.7353412 , 0 , 0 ),
280+ tolerance = 0.0001
281+ )
282+
283+ # Sampling versions of KernelSHAP are very close
284+ ks2 <- kernelshap(
270285 pf ,
271286 head(X , 1 ),
272287 bg_X = X ,
@@ -276,16 +291,17 @@ test_that("kernelshap and permshap work for models with high-order interactions"
276291 m = 1000 ,
277292 max_iter = 100 ,
278293 tol = 0.001 ,
279- verbose = FALSE
294+ verbose = FALSE ,
295+ seed = 1
280296 )
297+ expect_true(ks2 $ converged )
281298 expect_equal(
282- c(ksh2 $ S ),
283- c(- 1.194878 , - 1.24747 , - 0.9596389 , 3.883523 , - 0.3349787 , 0.5453894 ),
284- tolerance = 1e-4
299+ c(ks2 $ S ), c( - 44.52355 , - 32.93511 , 78.16014 , - 0.7202121 , - 0.02362539 , 0.03880312 ),
300+ # Exact c(-44.71698 , -32.89963, 78.34841, -0.7353412, 0 , 0)
301+ tolerance = 0.0001
285302 )
286303
287- set.seed(1 )
288- ksh1 <- kernelshap(
304+ ks1 <- kernelshap(
289305 pf ,
290306 head(X , 1 ),
291307 bg_X = X ,
@@ -294,37 +310,56 @@ test_that("kernelshap and permshap work for models with high-order interactions"
294310 exact = FALSE ,
295311 m = 1000 ,
296312 max_iter = 1000 ,
297- tol = 0.002 ,
298- verbose = FALSE
313+ tol = 0.001 ,
314+ verbose = FALSE ,
315+ seed = 1
299316 )
317+ expect_true(ks1 $ converged )
300318 expect_equal(
301- c(ksh1 $ S ),
302- c(- 1.196958 , - 1.256924 , - 0.9603291 , 3.886163 , - 0.3277153 , 0.5477104 ),
303- tolerance = 1e-3
319+ c(ks1 $ S ), c( - 44.8478 , - 32.81717 , 78.46633 , - 0.8514861 , 0.01075054 , 0.03582543 ),
320+ # Exact c(-44.71698 , -32.89963, 78.34841, -0.7353412, 0 , 0)
321+ tolerance = 0.001
304322 )
305323
306- set.seed(1 )
307- ksh0 <- suppressWarnings(
308- kernelshap(
309- pf ,
310- head(X , 1 ),
311- bg_X = X ,
312- pred_fun = pf ,
313- hybrid_degree = 0 ,
314- exact = FALSE ,
315- m = 10000 ,
316- max_iter = 10000 ,
317- tol = 0.003 ,
318- verbose = FALSE
319- )
324+ ks0 <- kernelshap(
325+ pf ,
326+ head(X , 1 ),
327+ bg_X = X ,
328+ pred_fun = pf ,
329+ hybrid_degree = 0 ,
330+ exact = FALSE ,
331+ m = 1000 ,
332+ max_iter = 1000 ,
333+ tol = 0.005 ,
334+ verbose = FALSE ,
335+ seed = 1
320336 )
337+ expect_true(ks0 $ converged )
321338 expect_equal(
322- c(ksh0 $ S ),
323- c(- 1.18917 , - 1.2298 , - 0.9247995 , 3.80673 , - 0.3144175 , 0.5434034 ),
324- tolerance = 1e-3
339+ c(ks0 $ S ), c( - 44.29753 , - 33.39267 , 78.67423 , - 0.290739 , - 0.3779175 , - 0.3189199 ),
340+ # Exact c(-44.71698 , -32.89963, 78.34841, -0.7353412, 0 , 0)
341+ tolerance = 0.001
325342 )
326- })
327343
344+ # Too slow for closer results, but we can see additive recovery for x4-x6
345+ pss <- permshap(
346+ pf ,
347+ head(X , 1 ),
348+ bg_X = X ,
349+ pred_fun = pf ,
350+ exact = FALSE ,
351+ max_iter = 100000 ,
352+ tol = 0.005 ,
353+ verbose = FALSE ,
354+ seed = 1
355+ )
356+ expect_true(pss $ converged )
357+ expect_equal(
358+ c(pss $ S ), c(- 44.36299 , - 32.79343 , 77.88822 , - 0.7353412 , 0 , 0 ),
359+ # Exact c(-44.71698, -32.89963, 78.34841, -0.7353412, 0, 0)
360+ tolerance = 0.001
361+ )
362+ })
328363
329364test_that(" Random seed works" , {
330365 n <- 100
@@ -344,9 +379,36 @@ test_that("Random seed works", {
344379 }
345380
346381 for (algo in c(permshap , kernelshap )) {
347- s1a <- algo(pf , head(X , 2 ), bg_X = X , pred_fun = pf , verbose = FALSE , seed = 1 , exact = FALSE , hybrid_degree = 0 )
348- s1b <- algo(pf , head(X , 2 ), bg_X = X , pred_fun = pf , verbose = FALSE , seed = 1 , exact = FALSE , hybrid_degree = 0 )
349- s2 <- algo(pf , head(X , 2 ), bg_X = X , pred_fun = pf , verbose = FALSE , seed = 2 , exact = FALSE , hybrid_degree = 0 )
382+ s1a <- algo(
383+ pf ,
384+ head(X , 2 ),
385+ bg_X = X ,
386+ pred_fun = pf ,
387+ verbose = FALSE ,
388+ seed = 1 ,
389+ exact = FALSE ,
390+ hybrid_degree = 0
391+ )
392+ s1b <- algo(
393+ pf ,
394+ head(X , 2 ),
395+ bg_X = X ,
396+ pred_fun = pf ,
397+ verbose = FALSE ,
398+ seed = 1 ,
399+ exact = FALSE ,
400+ hybrid_degree = 0
401+ )
402+ s2 <- algo(
403+ pf ,
404+ head(X , 2 ),
405+ bg_X = X ,
406+ pred_fun = pf ,
407+ verbose = FALSE ,
408+ seed = 2 ,
409+ exact = FALSE ,
410+ hybrid_degree = 0
411+ )
350412 expect_equal(s1a , s1b )
351413 expect_false(identical(s1a $ S , s2 $ S ))
352414 }
0 commit comments