Commit be06b94
Author: Frederik Rahbaek Warburg
Message: move flip outside for loop
1 parent 1ee5018 commit be06b94

File tree: 1 file changed, +57 -62 lines

stochman/nnj.py (57 additions, 62 deletions)
@@ -79,16 +79,16 @@ def _jacobian_wrt_input_transpose_mult_left_vec(self, x: Tensor, val: Tensor, ja

     def _jacobian_wrt_input(self, x: Tensor, val: Tensor) -> Tensor:
         return self.weight
-
+
     def _jacobian_wrt_weight(self, x: Tensor, val: Tensor) -> Tensor:
         b, c1 = x.shape
         c2 = val.shape[1]
         out_identity = torch.diag_embed(torch.ones(c2, device=x.device))
-        jacobian = torch.einsum('bk,ij->bijk', x, out_identity).reshape(b,c2,c2*c1)
+        jacobian = torch.einsum("bk,ij->bijk", x, out_identity).reshape(b, c2, c2 * c1)
         if self.bias is not None:
-            jacobian = torch.cat([jacobian, out_identity.unsqueeze(0).expand(b,-1,-1)], dim=2)
+            jacobian = torch.cat([jacobian, out_identity.unsqueeze(0).expand(b, -1, -1)], dim=2)
         return jacobian
-
+
     def _jacobian_wrt_input_sandwich(
         self, x: Tensor, val: Tensor, tmp: Tensor, diag_inp: bool = False, diag_out: bool = False
     ) -> Tensor:
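For reference, _jacobian_wrt_weight above builds the Jacobian of y = W x (+ b) with respect to the flattened weight, d y_i / d W_jk = delta_ij * x_k, stored as a (b, c2, c2 * c1) tensor. A minimal standalone sketch (hypothetical sizes, not part of stochman) that cross-checks the einsum against autograd:

import torch

b, c1, c2 = 3, 4, 5
x = torch.randn(b, c1)

# jacobian[n, i, j * c1 + k] = delta_ij * x[n, k], as in the diff above
out_identity = torch.diag_embed(torch.ones(c2))
jacobian = torch.einsum("bk,ij->bijk", x, out_identity).reshape(b, c2, c2 * c1)

# cross-check the first batch element against torch.autograd
W = torch.randn(c2, c1)
J_auto = torch.autograd.functional.jacobian(lambda w: x[0] @ w.T, W)  # (c2, c2, c1)
assert torch.allclose(jacobian[0], J_auto.reshape(c2, c2 * c1))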
@@ -112,54 +112,44 @@ def _jacobian_wrt_weight_sandwich(
             return self._jacobian_wrt_weight_sandwich_diag_to_full(x, val, tmp)
         elif diag_inp and diag_out:
             return self._jacobian_wrt_weight_sandwich_diag_to_diag(x, val, tmp)
-
-    def _jacobian_wrt_input_sandwich_full_to_full(
-        self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
+
+    def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         return torch.einsum("nm,bnj,jk->bmk", self.weight, tmp, self.weight)

-    def _jacobian_wrt_input_sandwich_full_to_diag(
-        self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
+    def _jacobian_wrt_input_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         return torch.einsum("nm,bnj,jm->bm", self.weight, tmp, self.weight)

-    def _jacobian_wrt_input_sandwich_diag_to_full(
-        self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
+    def _jacobian_wrt_input_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
         return torch.einsum("nm,bn,nk->bmk", self.weight, tmp_diag, self.weight)

-    def _jacobian_wrt_input_sandwich_diag_to_diag(
-        self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
+    def _jacobian_wrt_input_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
         return torch.einsum("nm,bn,nm->bm", self.weight, tmp_diag, self.weight)
-
-    def _jacobian_wrt_weight_sandwich_full_to_full(
-        self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
-        jacobian = self._jacobian_wrt_weight(x,val)
-        return torch.einsum('bji,bjk,bkq->biq', jacobian, tmp, jacobian)
-
-    def _jacobian_wrt_weight_sandwich_full_to_diag(
-        self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
+
+    def _jacobian_wrt_weight_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
+        jacobian = self._jacobian_wrt_weight(x, val)
+        return torch.einsum("bji,bjk,bkq->biq", jacobian, tmp, jacobian)
+
+    def _jacobian_wrt_weight_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         tmp_diag = torch.diagonal(tmp, dim1=1, dim2=2)
         return self._jacobian_wrt_weight_sandwich_diag_to_diag(x, val, tmp_diag)
-
-    def _jacobian_wrt_weight_sandwich_diag_to_full(
-        self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
-        jacobian = self._jacobian_wrt_weight(x,val)
-        return torch.einsum('bji,bj,bjq->biq', jacobian, tmp_diag, jacobian)

-    def _jacobian_wrt_weight_sandwich_diag_to_diag(
-        self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
+
+    def _jacobian_wrt_weight_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
+        jacobian = self._jacobian_wrt_weight(x, val)
+        return torch.einsum("bji,bj,bjq->biq", jacobian, tmp_diag, jacobian)
+
+    def _jacobian_wrt_weight_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:

         b, c1 = x.shape
         c2 = val.shape[1]

-        Jt_tmp_J = torch.bmm(tmp_diag.unsqueeze(2), (x**2).unsqueeze(1)).view(b, c1*c2)
+        Jt_tmp_J = torch.bmm(tmp_diag.unsqueeze(2), (x**2).unsqueeze(1)).view(b, c1 * c2)

         if self.bias is not None:
             Jt_tmp_J = torch.cat([Jt_tmp_J, tmp_diag], dim=1)

         return Jt_tmp_J


-
-
 class PosLinear(AbstractJacobian, nn.Linear):
     def forward(self, x: Tensor):
         bias = F.softplus(self.bias) if self.bias is not None else self.bias
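The *_sandwich helpers in this hunk all compute J^T tmp J for the layer Jacobian J; for this linear layer the Jacobian with respect to the input is just the weight matrix, so the full_to_full einsum is a batched W^T tmp W and the *_to_diag variants keep only its diagonal. A quick equivalence sketch (assumed shapes, not stochman code):

import torch

b, c1, c2 = 3, 4, 5
W = torch.randn(c2, c1)       # stand-in for self.weight
tmp = torch.randn(b, c2, c2)  # a batch of (c2, c2) matrices

# full_to_full: (b, c1, c1) result, equal to W^T @ tmp @ W per batch element
full = torch.einsum("nm,bnj,jk->bmk", W, tmp, W)
assert torch.allclose(full, W.T @ tmp @ W, atol=1e-5)

# full_to_diag: just the diagonal of the same product
diag = torch.einsum("nm,bnj,jm->bm", W, tmp, W)
assert torch.allclose(diag, torch.diagonal(full, dim1=1, dim2=2), atol=1e-5)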
@@ -212,45 +202,45 @@ def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp:
         b, c1, h1, w1 = x.shape
         c2, h2, w2 = val.shape[1:]

-        assert c1==c2
+        assert c1 == c2

         weight = torch.ones(1, 1, int(self.scale_factor), int(self.scale_factor), device=x.device)

-        tmp = tmp.reshape(b, c2, h2*w2, c2, h2*w2)
-        tmp = tmp.movedim(2,3)
+        tmp = tmp.reshape(b, c2, h2 * w2, c2, h2 * w2)
+        tmp = tmp.movedim(2, 3)
         tmp_J = F.conv2d(
-            tmp.reshape(b*c2*c2 * h2*w2, 1, h2, w2),
+            tmp.reshape(b * c2 * c2 * h2 * w2, 1, h2, w2),
             weight=weight,
             bias=None,
             stride=int(self.scale_factor),
             padding=0,
             dilation=1,
             groups=1,
-        ).reshape(b*c2*c2, h2*w2, h1*w1)
+        ).reshape(b * c2 * c2, h2 * w2, h1 * w1)

-        Jt_tmpt = tmp_J.movedim(-1,-2)
+        Jt_tmpt = tmp_J.movedim(-1, -2)

         Jt_tmpt_J = F.conv2d(
-            Jt_tmpt.reshape(b*c2*c2 * h1*w1, 1, h2, w2),
+            Jt_tmpt.reshape(b * c2 * c2 * h1 * w1, 1, h2, w2),
             weight=weight,
             bias=None,
             stride=int(self.scale_factor),
             padding=0,
             dilation=1,
             groups=1,
-        ).reshape(b*c2*c2, h1*w1, h1*w1)
+        ).reshape(b * c2 * c2, h1 * w1, h1 * w1)

-        Jt_tmp_J = Jt_tmpt_J.movedim(-1,-2)
+        Jt_tmp_J = Jt_tmpt_J.movedim(-1, -2)

-        Jt_tmp_J = Jt_tmp_J.reshape(b, c2, c2, h1*w1, h1*w1)
-        Jt_tmp_J = Jt_tmp_J.movedim(2,3)
-        Jt_tmp_J = Jt_tmp_J.reshape(b, c2*h1*w1, c2*h1*w1)
+        Jt_tmp_J = Jt_tmp_J.reshape(b, c2, c2, h1 * w1, h1 * w1)
+        Jt_tmp_J = Jt_tmp_J.movedim(2, 3)
+        Jt_tmp_J = Jt_tmp_J.reshape(b, c2 * h1 * w1, c2 * h1 * w1)

         return Jt_tmp_J

     def _jacobian_wrt_input_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         raise NotImplementedError
-
+
     def _jacobian_wrt_input_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
         raise NotImplementedError

@@ -272,6 +262,7 @@ def _jacobian_wrt_input_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_

         return tmp_diag.reshape(b, c1 * h1 * w1)

+
 class Conv1d(AbstractJacobian, nn.Conv1d):
     def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         b, c1, l1 = x.shape
@@ -668,7 +659,6 @@ def _jacobian_wrt_input_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_
         diag_Jt_tmp_J = output_tmp.reshape(b, c1 * h1 * w1)
         return diag_Jt_tmp_J

-
     def _jacobian_wrt_weight_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         return self._jacobian_wrt_weight_mult_left(
             x, val, self._jacobian_wrt_weight_T_mult_right(x, val, tmp)
@@ -677,7 +667,7 @@ def _jacobian_wrt_weight_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp
     def _jacobian_wrt_weight_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         ### TODO: Implement this in a smarter way
         return torch.diagonal(self._jacobian_wrt_weight_sandwich_full_to_full(x, val, tmp), dim1=1, dim2=2)
-
+
     def _jacobian_wrt_weight_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
         raise NotImplementedError

@@ -695,12 +685,11 @@ def _jacobian_wrt_weight_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp

         # define moving sum for Jt_tmp
         output_tmp = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, device=x.device)
+        flip_squared_input = torch.flip(x, [-3, -2, -1]).movedim(0, 1) ** 2
+
         for i in range(b):
             # set the weight to the convolution
-            input_single_batch = x[i : i + 1, :, :, :]
-            reversed_input_single_batch = torch.flip(input_single_batch, [-3, -2, -1]).movedim(0, 1)
-            weigth_sq = reversed_input_single_batch**2
-
+            weigth_sq = flip_squared_input[:, i : i + 1, :, :]
             input_tmp_single_batch = input_tmp[:, i : i + 1, :, :]

             output_tmp_single_batch = (
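This hunk is the commit's headline change: the flip of the input is hoisted out of the per-batch loop, so it runs once instead of b times, and each iteration only slices the precomputed tensor. A toy check of the equivalence (made-up shapes, independent of the surrounding code):

import torch

b, c, h, w = 4, 3, 5, 5
x = torch.randn(b, c, h, w)

# hoisted: flip channel/height/width once for the whole batch, then square
flip_squared_input = torch.flip(x, [-3, -2, -1]).movedim(0, 1) ** 2  # (c, b, h, w)

for i in range(b):
    # old per-iteration version: flip a single batch element inside the loop
    per_iter = torch.flip(x[i : i + 1], [-3, -2, -1]).movedim(0, 1) ** 2
    # new version: slice the precomputed tensor; results match exactly
    assert torch.equal(per_iter, flip_squared_input[:, i : i + 1, :, :])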
@@ -976,26 +965,32 @@ def _jacobian_wrt_input_sandwich(
     def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         b, c1, h1, w1 = x.shape
         c2, h2, w2 = val.shape[1:]
-        assert c1==c2
+        assert c1 == c2

-        tmp = tmp.reshape(b, c1, h2*w2, c1, h2*w2).movedim(-2,-3).reshape(b*c1*c1, h2*w2, h2*w2)
-        Jt_tmp_J = torch.zeros((b*c1*c1, h1*w1, h1*w1), device=tmp.device)
+        tmp = tmp.reshape(b, c1, h2 * w2, c1, h2 * w2).movedim(-2, -3).reshape(b * c1 * c1, h2 * w2, h2 * w2)
+        Jt_tmp_J = torch.zeros((b * c1 * c1, h1 * w1, h1 * w1), device=tmp.device)
         # indexes for batch and channel
-        arange_repeated = torch.repeat_interleave(torch.arange(b*c1*c1), h2*w2 * h2*w2).long()
-        arange_repeated = arange_repeated.reshape(b*c1*c1, h2*w2, h2*w2)
+        arange_repeated = torch.repeat_interleave(torch.arange(b * c1 * c1), h2 * w2 * h2 * w2).long()
+        arange_repeated = arange_repeated.reshape(b * c1 * c1, h2 * w2, h2 * w2)
         # indexes for height and width
-        idx = self.idx.reshape(b, c1, h2 * w2).unsqueeze(2).expand(-1, -1, h2*w2, -1)
-        idx_col = idx.unsqueeze(1).expand(-1, c1, -1, -1, -1).reshape(b*c1*c1, h2*w2, h2*w2)
-        idx_row = idx.unsqueeze(2).expand(-1, -1, c1, -1, -1).reshape(b*c1*c1, h2*w2, h2*w2).movedim(-1,-2)
-
+        idx = self.idx.reshape(b, c1, h2 * w2).unsqueeze(2).expand(-1, -1, h2 * w2, -1)
+        idx_col = idx.unsqueeze(1).expand(-1, c1, -1, -1, -1).reshape(b * c1 * c1, h2 * w2, h2 * w2)
+        idx_row = (
+            idx.unsqueeze(2).expand(-1, -1, c1, -1, -1).reshape(b * c1 * c1, h2 * w2, h2 * w2).movedim(-1, -2)
+        )
+
         Jt_tmp_J[arange_repeated, idx_row, idx_col] = tmp
-        Jt_tmp_J = Jt_tmp_J.reshape(b, c1, c1, h1*w1, h1*w1).movedim(-2,-3).reshape(b, c1*h1*w1, c1*h1*w1)
+        Jt_tmp_J = (
+            Jt_tmp_J.reshape(b, c1, c1, h1 * w1, h1 * w1)
+            .movedim(-2, -3)
+            .reshape(b, c1 * h1 * w1, c1 * h1 * w1)
+        )

         return Jt_tmp_J

     def _jacobian_wrt_input_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         raise NotImplementedError
-
+

     def _jacobian_wrt_input_sandwich_diag_to_full(self, x: Tensor, val: Tensor, diag_tmp: Tensor) -> Tensor:
         raise NotImplementedError

@@ -1125,7 +1120,7 @@ def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp:
         jac = torch.diag_embed(jac.view(x.shape[0], -1))
         tmp = torch.einsum("bnm,bnj,bjk->bmk", jac, tmp, jac)
         return tmp
-
+
     def _jacobian_wrt_input_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         jac = self._jacobian(x, val)
         jac = torch.diag_embed(jac.view(x.shape[0], -1))
