Skip to content

Commit c59cefa

Browse files
author
Frederik Rahbaek Warburg
committed
bug fix and add new sandwiches
1 parent 2c48179 commit c59cefa

File tree

1 file changed

+48
-10
lines changed

1 file changed

+48
-10
lines changed

stochman/nnj.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,8 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
7676
def _jacobian_wrt_input_transpose_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
7777
return F.linear(jac_in.movedim(1, -1), self.weight.T, bias=None).movedim(-1, 1)
7878

79-
def _sandwich_full_wrt_input(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
80-
return torch.einsum("nm,bnj,jk->bmk", self.weight, tmp, self.weight)
81-
82-
def _sandwich_full_wrt_weight(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
79+
def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
80+
8381
b, c = x.shape
8482
diag_elements = torch.diagonal(tmp, dim1=1, dim2=2)
8583
feat_k2 = (x**2).unsqueeze(1)
@@ -92,6 +90,9 @@ def _sandwich_full_wrt_weight(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tens
9290

9391
return h_k
9492

93+
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
94+
return torch.einsum("nm,bnj,jk->bmk", self.weight, tmp, self.weight)
95+
9596

9697
class PosLinear(AbstractJacobian, nn.Linear):
9798
def forward(self, x: Tensor):
@@ -123,6 +124,14 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
123124
.movedim(dims2, dims1)
124125
)
125126

127+
def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
128+
# non parametric, so return empty
129+
return []
130+
131+
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
132+
raise NotImplementedError
133+
134+
126135

127136
class Conv1d(AbstractJacobian, nn.Conv1d):
128137
def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
@@ -358,7 +367,7 @@ def _jacobian_wrt_weight_T_mult_right(
358367
dilation=self.dilation,
359368
groups=self.groups,
360369
)
361-
.reshape(b, *tmp_single_batch.shape[4:], c1, kernel_h, kernel_w)
370+
.reshape(c2, *tmp_single_batch.shape[4:], c1, kernel_h, kernel_w)
362371
.movedim((-3, -2, -1), (1, 2, 3))
363372
)
364373

@@ -381,7 +390,7 @@ def _jacobian_wrt_weight_T_mult_right(
381390
dilation=self.dilation,
382391
groups=self.groups,
383392
)
384-
.reshape(b, *tmp.shape[4:], c1, kernel_h, kernel_w)
393+
.reshape(c2, *tmp.shape[4:], c1, kernel_h, kernel_w)
385394
.movedim((-3, -2, -1), (1, 2, 3))
386395
)
387396

@@ -390,6 +399,18 @@ def _jacobian_wrt_weight_T_mult_right(
390399

391400
return Jt_tmp
392401

402+
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
403+
if diag:
404+
return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
405+
else:
406+
return self._jacobian_wrt_input_full_sandwich(x, val, tmp)
407+
408+
def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
409+
if diag:
410+
return self._jacobian_wrt_weight_diag_sandwich(x, val, tmp)
411+
else:
412+
return self._jacobian_wrt_weight_full_sandwich(x, val, tmp)
413+
393414
def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
394415
return self._jacobian_wrt_input_mult_left(x, val, self._jacobian_wrt_input_T_mult_right(x, val, tmp))
395416

@@ -662,14 +683,21 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
662683
jac_in = jac_in[arange_repeated, idx, :, :, :].reshape(*val.shape, *jac_in_orig_shape[4:])
663684
return jac_in
664685

665-
def _sandwich_full_wrt_input(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
686+
def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
687+
# non parametric, so return empty
688+
return []
689+
690+
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
691+
return self._jacobian_wrt_input_full_sandwich(x, val, tmp)
692+
693+
def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
666694

667695
new_tmp = torch.zeros_like(x)
668696
new_tmp[self.idx] = tmp
669697

670698
return new_tmp
671699

672-
def _sandwich_diag_wrt_input(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
700+
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
673701
pass
674702

675703

@@ -758,15 +786,25 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
758786
jac = 1.0 - val**2
759787
return jac
760788

761-
def _sandwich_full_wrt_input(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
789+
def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
790+
# non parametric, so return empty
791+
return []
792+
793+
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
794+
if diag:
795+
return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
796+
else:
797+
return self._jacobian_wrt_input_full_sandwich(x, val, tmp)
798+
799+
def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
762800

763801
jac = self._jacobian(x, val)
764802
jac = torch.diag_embed(jac.view(x.shape[0], -1))
765803
tmp = torch.einsum("bnm,bnj,bjk->bmk", jac, tmp, jac)
766804

767805
return tmp
768806

769-
def _sandwich_diag_wrt_input(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
807+
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
770808

771809
jac = self._jacobian(x, val)
772810
jac = jac.view(x.shape[0], -1)

0 commit comments

Comments
 (0)