Skip to content

Commit 43456ff

Browse files
added upsampling
1 parent d16119d commit 43456ff

File tree

1 file changed

+58
-7
lines changed

1 file changed

+58
-7
lines changed

stochman/nnj.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,19 @@ def _jacobian_wrt_weight_sandwich(
134134
return None
135135

136136
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
    """Dispatch the input-Jacobian sandwich to the diagonal or full variant.

    Args:
        x: Input of the layer; forwarded unchanged.
        val: Output of the layer; forwarded unchanged.
        tmp: Quantity to be sandwiched; forwarded unchanged.
        diag: When true, ``_jacobian_wrt_input_diag_sandwich`` is used,
            otherwise ``_jacobian_wrt_input_full_sandwich``.

    Returns:
        Whatever the selected variant returns.
    """
    # Only the selected variant is looked up on ``self`` and called.
    sandwich = self._jacobian_wrt_input_diag_sandwich if diag else self._jacobian_wrt_input_full_sandwich
    return sandwich(x, val, tmp)
137141

142+
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, diag_tmp: Tensor) -> Tensor:
138143
b, c1, h1, w1 = x.shape
139144
c2, h2, w2 = val.shape[1:]
140145

141146
weight = torch.ones(c2, c1, int(self.scale_factor), int(self.scale_factor), device=x.device)
142147

143-
tmp = F.conv2d(
144-
tmp.reshape(-1, c2, h2, w2),
148+
diag_tmp = F.conv2d(
149+
diag_tmp.reshape(-1, c2, h2, w2),
145150
weight=weight,
146151
bias=None,
147152
stride=int(self.scale_factor),
@@ -150,8 +155,47 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag
150155
groups=1,
151156
)
152157

153-
return tmp.reshape(b, c1 * h1 * w1)
158+
return diag_tmp.reshape(b, c1 * h1 * w1)
159+
160+
def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Compute the full sandwich ``J^T @ tmp @ J`` for this layer's input Jacobian.

    ``J`` is never materialised. Both contractions are done by convolving
    with an all-ones ``scale x scale`` kernel at stride ``scale``, which sums
    every ``scale x scale`` block of the output grid into one input location.
    NOTE(review): this matches nearest-neighbour upsampling with an integer
    ``scale_factor`` -- confirm that no other mode reaches this method.

    Args:
        x: Layer input of shape ``(b, c1, h1, w1)``.
        val: Layer output of shape ``(b, c2, h2, w2)``.
        tmp: Matrix to sandwich; must be reshapeable to
            ``(b, c2 * h2 * w2, c2 * h2 * w2)``.

    Returns:
        Tensor of shape ``(b, c2 * h1 * w1, c2 * h1 * w1)`` with the dtype
        of ``tmp``.

    Raises:
        ValueError: If input and output channel counts differ.
    """
    b, c1, h1, w1 = x.shape
    c2, h2, w2 = val.shape[1:]

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if c1 != c2:
        raise ValueError(f"input and output must have the same number of channels, got {c1} and {c2}")

    scale = int(self.scale_factor)
    # The kernel only ever convolves ``tmp``, so it must share its dtype and
    # device; a default-dtype kernel makes conv2d fail for e.g. float64 input.
    weight = torch.ones(1, 1, scale, scale, device=tmp.device, dtype=tmp.dtype)

    def _sum_blocks(mat: Tensor) -> Tensor:
        # (n, rows, h2*w2) -> (n, rows, h1*w1): contract the last axis with J.
        n, rows = mat.shape[:2]
        pooled = F.conv2d(
            mat.reshape(n * rows, 1, h2, w2),
            weight=weight,
            bias=None,
            stride=scale,
            padding=0,
            dilation=1,
            groups=1,
        )
        return pooled.reshape(n, rows, h1 * w1)

    # Gather the two channel axes in front: (b*c2*c2, h2*w2, h2*w2).
    tmp = tmp.reshape(b, c2, h2 * w2, c2, h2 * w2).movedim(2, 3)
    tmp = tmp.reshape(b * c2 * c2, h2 * w2, h2 * w2)

    # tmp @ J, then transpose so the remaining output axis is last.
    Jt_tmpt = _sum_blocks(tmp).movedim(-1, -2)
    # (J^T tmp^T J), transposed back to J^T tmp J.
    Jt_tmp_J = _sum_blocks(Jt_tmpt).movedim(-1, -2)

    # Restore the (channel, pixel) ordering on both matrix axes.
    Jt_tmp_J = Jt_tmp_J.reshape(b, c2, c2, h1 * w1, h1 * w1).movedim(2, 3)
    return Jt_tmp_J.reshape(b, c2 * h1 * w1, c2 * h1 * w1)
155199

156200
class Conv1d(AbstractJacobian, nn.Conv1d):
157201
def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
@@ -490,7 +534,7 @@ def _jacobian_wrt_weight_mult_left(
490534
# transpose
491535
tmp_J = Jt_tmptt_cols.movedim(0, 1)
492536

493-
return tmp
537+
return tmp_J
494538

495539
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
496540
if diag:
@@ -801,9 +845,12 @@ def _jacobian_wrt_weight_sandwich(
801845
return None
802846

803847
def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
804-
return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
848+
if diag:
849+
return self._jacobian_wrt_input_diag_sandwich(x, val, tmp)
850+
else:
851+
return self._jacobian_wrt_input_full_sandwich(x, val, tmp)
805852

806-
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
853+
def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, diag_tmp: Tensor) -> Tensor:
807854
b, c1, h1, w1 = x.shape
808855
c2, h2, w2 = val.shape[1:]
809856

@@ -813,10 +860,14 @@ def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor)
813860
arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * w2).long()
814861
arange_repeated = arange_repeated.reshape(b * c2, h2 * w2)
815862

816-
new_tmp[arange_repeated, idx] = tmp.reshape(b * c2, h2 * w2)
863+
new_tmp[arange_repeated, idx] = diag_tmp.reshape(b * c2, h2 * w2)
817864

818865
return new_tmp.reshape(b, c1 * h1 * w1)
819866

867+
def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
868+
869+
# Returns `tmp` unchanged; `x` and `val` are not used.
# NOTE(review): this looks like a placeholder -- no Jacobian is applied here,
# so the result keeps the shape of `tmp`. Confirm before relying on diag=False.
return tmp
870+
820871

821872
class MaxPool3d(AbstractJacobian, nn.MaxPool3d):
822873
def forward(self, input: Tensor):

0 commit comments

Comments
 (0)