@@ -15,25 +15,32 @@ kernelshap_one <- function(
1515 max_iter ,
1616 v0 ,
1717 precalc ,
18+ bg_n ,
1819 ... ) {
1920 p <- length(feature_names )
2021 K <- ncol(v1 )
2122 K_names <- colnames(v1 )
2223
2324 # Calculate A_exact and b_exact
2425 if (exact || deg > = 1L ) {
25- A_exact <- precalc [[" A" ]] # (p x p)
26- bg_X_exact <- precalc [[" bg_X_exact" ]] # (m_ex*n_bg x p)
27- Z <- precalc [[" Z" ]] # (m_ex x p)
26+ A_exact <- precalc $ A # (p x p)
27+ Z <- precalc $ Z # (m_ex x p)
2828 m_exact <- nrow(Z )
2929 v0_m_exact <- v0 [rep.int(1L , m_exact ), , drop = FALSE ] # (m_ex x K)
3030
3131 # Most expensive part
3232 vz <- get_vz(
33- x = x , bg = bg_X_exact , Z = Z , object = object , pred_fun = pred_fun , w = bg_w , ...
33+ x = x ,
34+ bg_rep = precalc $ bg_exact_rep , # (m_ex*bg_n x p)
35+ Z_rep = precalc $ Z_exact_rep , # (m_ex*bg_n x p)
36+ object = object ,
37+ pred_fun = pred_fun ,
38+ w = bg_w ,
39+ bg_n = bg_n ,
40+ ...
3441 )
3542 # Note: w is correctly replicated along columns of (vz - v0_m_exact)
36- b_exact <- crossprod(Z , precalc [[ " w " ]] * (vz - v0_m_exact )) # (p x K)
43+ b_exact <- crossprod(Z , precalc $ w * (vz - v0_m_exact )) # (p x K)
3744
3845 # Some of the hybrid cases are exact as well
3946 if (exact || trunc(p / 2 ) == deg ) {
@@ -43,7 +50,8 @@ kernelshap_one <- function(
4350 }
4451
4552 # Iterative sampling part, always using A_exact and b_exact to fill up the weights
46- bg_X_m <- precalc [[" bg_X_m" ]] # (m*n_bg x p)
53+ g <- rep_each(m , each = bg_n )
54+
4755 v0_m <- v0 [rep.int(1L , m ), , drop = FALSE ] # (m x K)
4856 est_m <- array (
4957 data = 0 , dim = c(max_iter , p , K ), dimnames = list (NULL , feature_names , K_names )
@@ -62,16 +70,23 @@ kernelshap_one <- function(
6270 while (! converged && n_iter < max_iter ) {
6371 n_iter <- n_iter + 1L
6472 input <- input_sampling(p = p , m = m , deg = deg , feature_names = feature_names )
65- Z <- input [[ " Z " ]]
73+ Z <- input $ Z
6674
6775 # Expensive # (m x K)
6876 vz <- get_vz(
69- x = x , bg = bg_X_m , Z = Z , object = object , pred_fun = pred_fun , w = bg_w , ...
77+ x = x ,
78+ bg_rep = precalc $ bg_sampling_rep , # (m*bg_n x p)
79+ Z_rep = Z [g , , drop = FALSE ],
80+ object = object ,
81+ pred_fun = pred_fun ,
82+ w = bg_w ,
83+ bg_n = bg_n ,
84+ ...
7085 )
7186
72- # The sum of weights of A_exact and input[["A"]] is 1, same for b
73- A_temp <- A_exact + input [[ " A " ]] # (p x p)
74- b_temp <- b_exact + crossprod(Z , input [[ " w " ]] * (vz - v0_m )) # (p x K)
87+ # The sum of weights of A_exact and input$A is 1, same for b
88+ A_temp <- A_exact + input $ A # (p x p)
89+ b_temp <- b_exact + crossprod(Z , input $ w * (vz - v0_m )) # (p x K)
7590 A_sum <- A_sum + A_temp # (p x p)
7691 b_sum <- b_sum + b_temp # (p x K)
7792
0 commit comments