Skip to content

Commit 3998ff1

Browse files
Added bias for conv: diagonal case
1 parent be06b94 commit 3998ff1

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

stochman/nnj.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,11 @@ def _jacobian_wrt_weight_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp
711711
output_tmp_single_batch = output_tmp_single_batch.reshape(c2 * c1 * kernel_h * kernel_w)
712712
output_tmp[i, :] = output_tmp_single_batch
713713

714+
if self.bias is not None:
715+
bias_term = tmp_diag.reshape(b, c2, h2*w2)
716+
bias_term = torch.sum(bias_term, 2)
717+
output_tmp = torch.cat([output_tmp, bias_term], dim=1)
718+
714719
return output_tmp
715720

716721

0 commit comments

Comments
 (0)