Skip to content

Commit 983177d

Browse files
Added maxpool
1 parent ba868e3 commit 983177d

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

stochman/nnj.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag
142142
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, diag_tmp: Tensor) -> Tensor:
143143
b, c1, h1, w1 = x.shape
144144
c2, h2, w2 = val.shape[1:]
145-
145+
146146
weight = torch.ones(c2, c1, int(self.scale_factor), int(self.scale_factor), device=x.device)
147147

148148
diag_tmp = F.conv2d(
@@ -856,17 +856,36 @@ def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, diag_tmp: Te
856856

857857
new_tmp = torch.zeros_like(x)
858858
new_tmp = new_tmp.reshape(b * c1, h1 * w1)
859-
idx = self.idx.reshape(b * c2, h2 * w2)
859+
860+
# indexes for batch and channel
860861
arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * w2).long()
861862
arange_repeated = arange_repeated.reshape(b * c2, h2 * w2)
863+
# indexes for height and width
864+
idx = self.idx.reshape(b * c2, h2 * w2)
862865

863866
new_tmp[arange_repeated, idx] = diag_tmp.reshape(b * c2, h2 * w2)
864867

865868
return new_tmp.reshape(b, c1 * h1 * w1)
866869

867870
def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
871+
b, c1, h1, w1 = x.shape
872+
c2, h2, w2 = val.shape[1:]
873+
assert c1==c2
868874

869-
return tmp
875+
tmp = tmp.reshape(b, c1, h2*w2, c1, h2*w2).movedim(-2,-3).reshape(b*c1*c1, h2*w2, h2*w2)
876+
Jt_tmp_J = torch.zeros((b*c1*c1, h1*w1, h1*w1))
877+
# indexes for batch and channel
878+
arange_repeated = torch.repeat_interleave(torch.arange(b*c1*c1), h2*w2 * h2*w2).long()
879+
arange_repeated = arange_repeated.reshape(b*c1*c1, h2*w2, h2*w2)
880+
# indexes for height and width
881+
idx = self.idx.reshape(b, c1, h2 * w2).unsqueeze(2).expand(-1, -1, h2*w2, -1)
882+
idx_col = idx.unsqueeze(1).expand(-1, c1, -1, -1, -1).reshape(b*c1*c1, h2*w2, h2*w2)
883+
idx_row = idx.unsqueeze(2).expand(-1, -1, c1, -1, -1).reshape(b*c1*c1, h2*w2, h2*w2).movedim(-1,-2)
884+
885+
Jt_tmp_J[arange_repeated, idx_row, idx_col] = tmp
886+
Jt_tmp_J = Jt_tmp_J.reshape(b, c1, c1, h1*w1, h1*w1).movedim(-2,-3).reshape(b, c1*h1*w1, c1*h1*w1)
887+
888+
return Jt_tmp_J
870889

871890

872891
class MaxPool3d(AbstractJacobian, nn.MaxPool3d):

0 commit comments

Comments
 (0)