@@ -79,16 +79,16 @@ def _jacobian_wrt_input_transpose_mult_left_vec(self, x: Tensor, val: Tensor, ja
79
79
80
80
def _jacobian_wrt_input(self, x: Tensor, val: Tensor) -> Tensor:
    """Return ``self.weight`` as the Jacobian with respect to the input.

    Args:
        x: layer input; not read here.
        val: layer output; not read here.

    Returns:
        The ``self.weight`` tensor itself (no copy is made).
    """
    return self.weight
82
-
82
+
83
83
def _jacobian_wrt_weight(self, x: Tensor, val: Tensor) -> Tensor:
    """Build the Jacobian of the layer output with respect to its parameters.

    Args:
        x: input of shape ``(b, c1)``.
        val: output whose second dimension ``c2`` is the number of outputs.

    Returns:
        Tensor of shape ``(b, c2, c2 * c1)``; when ``self.bias`` is not None a
        ``c2 x c2`` identity block is appended along the last dimension, giving
        ``(b, c2, c2 * c1 + c2)``.
    """
    b, c1 = x.shape
    c2 = val.shape[1]
    # Match x's dtype (not just its device) so the result can be contracted
    # with tensors derived from x without a dtype mismatch.
    out_identity = torch.eye(c2, device=x.device, dtype=x.dtype)
    # jacobian[b, i, j, k] = x[b, k] * (i == j); flattening (j, k) gives one
    # column per weight entry.
    jacobian = torch.einsum(" bk,ij->bijk", x, out_identity).reshape(b, c2, c2 * c1)
    if self.bias is not None:
        # Each output depends on its own bias entry with unit slope.
        jacobian = torch.cat([jacobian, out_identity.unsqueeze(0).expand(b, -1, -1)], dim=2)
    return jacobian
91
-
91
+
92
92
def _jacobian_wrt_input_sandwich (
93
93
self , x : Tensor , val : Tensor , tmp : Tensor , diag_inp : bool = False , diag_out : bool = False
94
94
) -> Tensor :
@@ -112,54 +112,44 @@ def _jacobian_wrt_weight_sandwich(
112
112
return self ._jacobian_wrt_weight_sandwich_diag_to_full (x , val , tmp )
113
113
elif diag_inp and diag_out :
114
114
return self ._jacobian_wrt_weight_sandwich_diag_to_diag (x , val , tmp )
115
-
116
def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Sandwich a full matrix between the transposed weight and the weight.

    Computes ``out[b] = W^T @ tmp[b] @ W`` with ``W = self.weight``.

    Args:
        x: layer input; not read here.
        val: layer output; not read here.
        tmp: batch of matrices indexed ``(b, n, j)``, where ``n`` and ``j``
            both run over the first axis of ``self.weight``.

    Returns:
        Tensor indexed ``(b, m, k)`` over the second axis of ``self.weight``.
    """
    return torch.einsum("nm,bnj,jk->bmk", self.weight, tmp, self.weight)
119
118
120
def _jacobian_wrt_input_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Return only the diagonal of ``W^T @ tmp[b] @ W`` with ``W = self.weight``.

    Args:
        x: layer input; not read here.
        val: layer output; not read here.
        tmp: batch of matrices indexed ``(b, n, j)`` over the first axis of
            ``self.weight``.

    Returns:
        Tensor indexed ``(b, m)`` over the second axis of ``self.weight``.
    """
    # Repeating the output index ``m`` on both weight factors selects the diagonal
    # without materialising the full product.
    return torch.einsum("nm,bnj,jm->bm", self.weight, tmp, self.weight)
123
121
124
def _jacobian_wrt_input_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Compute ``W^T @ diag(tmp_diag[b]) @ W`` with ``W = self.weight``.

    Args:
        x: layer input; not read here.
        val: layer output; not read here.
        tmp_diag: batch of diagonals indexed ``(b, n)`` over the first axis of
            ``self.weight``.

    Returns:
        Tensor indexed ``(b, m, k)`` over the second axis of ``self.weight``.
    """
    return torch.einsum("nm,bn,nk->bmk", self.weight, tmp_diag, self.weight)
127
124
128
def _jacobian_wrt_input_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Return the diagonal of ``W^T @ diag(tmp_diag[b]) @ W`` with ``W = self.weight``.

    Args:
        x: layer input; not read here.
        val: layer output; not read here.
        tmp_diag: batch of diagonals indexed ``(b, n)`` over the first axis of
            ``self.weight``.

    Returns:
        Tensor indexed ``(b, m)``; entry ``m`` is ``sum_n W[n, m]**2 * tmp_diag[b, n]``.
    """
    return torch.einsum("nm,bn,nm->bm", self.weight, tmp_diag, self.weight)
131
-
132
def _jacobian_wrt_weight_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Compute ``J^T @ tmp @ J`` per batch element, with ``J`` the weight Jacobian.

    Args:
        x: layer input, forwarded to ``self._jacobian_wrt_weight``.
        val: layer output, forwarded to ``self._jacobian_wrt_weight``.
        tmp: batch of matrices indexed ``(b, j, k)`` over the Jacobian's
            second axis.

    Returns:
        Tensor indexed ``(b, i, q)`` over the Jacobian's last axis.
    """
    jacobian = self._jacobian_wrt_weight(x, val)
    return torch.einsum("bji,bjk,bkq->biq", jacobian, tmp, jacobian)
131
+
132
def _jacobian_wrt_weight_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Reduce the full-to-diag weight sandwich to the diag-to-diag case.

    Only the diagonal of ``tmp`` (taken over dimensions 1 and 2) is kept; it
    is then passed to ``self._jacobian_wrt_weight_sandwich_diag_to_diag``.

    Args:
        x: layer input, forwarded unchanged.
        val: layer output, forwarded unchanged.
        tmp: batch of square matrices with the batch on dimension 0.

    Returns:
        Whatever ``self._jacobian_wrt_weight_sandwich_diag_to_diag`` returns.
    """
    tmp_diag = torch.diagonal(tmp, dim1=1, dim2=2)
    return self._jacobian_wrt_weight_sandwich_diag_to_diag(x, val, tmp_diag)
141
-
142
def _jacobian_wrt_weight_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Compute ``J^T @ diag(tmp_diag) @ J`` per batch element, with ``J`` the weight Jacobian.

    Args:
        x: layer input, forwarded to ``self._jacobian_wrt_weight``.
        val: layer output, forwarded to ``self._jacobian_wrt_weight``.
        tmp_diag: batch of diagonals indexed ``(b, j)`` over the Jacobian's
            second axis.

    Returns:
        Tensor indexed ``(b, i, q)`` over the Jacobian's last axis.
    """
    jacobian = self._jacobian_wrt_weight(x, val)
    return torch.einsum("bji,bj,bjq->biq", jacobian, tmp_diag, jacobian)
139
+
140
def _jacobian_wrt_weight_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Return the diagonal of the weight sandwich for a diagonal ``tmp``.

    Args:
        x: input of shape ``(b, c1)``.
        val: output whose second dimension ``c2`` is the number of outputs.
        tmp_diag: batch of diagonals of shape ``(b, c2)``.

    Returns:
        Tensor of shape ``(b, c1 * c2)``, with ``tmp_diag`` appended along
        dimension 1 when ``self.bias`` is not None.
    """
    b, c1 = x.shape
    c2 = val.shape[1]

    # Outer product tmp_diag[b, j] * x[b, k]**2, flattened with j as the slow index.
    Jt_tmp_J = torch.bmm(tmp_diag.unsqueeze(2), (x**2).unsqueeze(1)).view(b, c1 * c2)

    if self.bias is not None:
        # The bias entries contribute tmp_diag unchanged.
        Jt_tmp_J = torch.cat([Jt_tmp_J, tmp_diag], dim=1)

    return Jt_tmp_J
159
151
160
152
161
-
162
-
163
153
class PosLinear (AbstractJacobian , nn .Linear ):
164
154
def forward (self , x : Tensor ):
165
155
bias = F .softplus (self .bias ) if self .bias is not None else self .bias
@@ -212,45 +202,45 @@ def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp:
212
202
b , c1 , h1 , w1 = x .shape
213
203
c2 , h2 , w2 = val .shape [1 :]
214
204
215
- assert c1 == c2
205
+ assert c1 == c2
216
206
217
207
weight = torch .ones (1 , 1 , int (self .scale_factor ), int (self .scale_factor ), device = x .device )
218
208
219
- tmp = tmp .reshape (b , c2 , h2 * w2 , c2 , h2 * w2 )
220
- tmp = tmp .movedim (2 ,3 )
209
+ tmp = tmp .reshape (b , c2 , h2 * w2 , c2 , h2 * w2 )
210
+ tmp = tmp .movedim (2 , 3 )
221
211
tmp_J = F .conv2d (
222
- tmp .reshape (b * c2 * c2 * h2 * w2 , 1 , h2 , w2 ),
212
+ tmp .reshape (b * c2 * c2 * h2 * w2 , 1 , h2 , w2 ),
223
213
weight = weight ,
224
214
bias = None ,
225
215
stride = int (self .scale_factor ),
226
216
padding = 0 ,
227
217
dilation = 1 ,
228
218
groups = 1 ,
229
- ).reshape (b * c2 * c2 , h2 * w2 , h1 * w1 )
219
+ ).reshape (b * c2 * c2 , h2 * w2 , h1 * w1 )
230
220
231
- Jt_tmpt = tmp_J .movedim (- 1 ,- 2 )
221
+ Jt_tmpt = tmp_J .movedim (- 1 , - 2 )
232
222
233
223
Jt_tmpt_J = F .conv2d (
234
- Jt_tmpt .reshape (b * c2 * c2 * h1 * w1 , 1 , h2 , w2 ),
224
+ Jt_tmpt .reshape (b * c2 * c2 * h1 * w1 , 1 , h2 , w2 ),
235
225
weight = weight ,
236
226
bias = None ,
237
227
stride = int (self .scale_factor ),
238
228
padding = 0 ,
239
229
dilation = 1 ,
240
230
groups = 1 ,
241
- ).reshape (b * c2 * c2 , h1 * w1 , h1 * w1 )
231
+ ).reshape (b * c2 * c2 , h1 * w1 , h1 * w1 )
242
232
243
- Jt_tmp_J = Jt_tmpt_J .movedim (- 1 ,- 2 )
233
+ Jt_tmp_J = Jt_tmpt_J .movedim (- 1 , - 2 )
244
234
245
- Jt_tmp_J = Jt_tmp_J .reshape (b , c2 , c2 , h1 * w1 , h1 * w1 )
246
- Jt_tmp_J = Jt_tmp_J .movedim (2 ,3 )
247
- Jt_tmp_J = Jt_tmp_J .reshape (b , c2 * h1 * w1 , c2 * h1 * w1 )
235
+ Jt_tmp_J = Jt_tmp_J .reshape (b , c2 , c2 , h1 * w1 , h1 * w1 )
236
+ Jt_tmp_J = Jt_tmp_J .movedim (2 , 3 )
237
+ Jt_tmp_J = Jt_tmp_J .reshape (b , c2 * h1 * w1 , c2 * h1 * w1 )
248
238
249
239
return Jt_tmp_J
250
240
251
241
def _jacobian_wrt_input_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Full-to-diag input sandwich; not implemented for this layer.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
253
-
243
+
254
244
def _jacobian_wrt_input_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Diag-to-full input sandwich; not implemented for this layer.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
256
246
@@ -272,6 +262,7 @@ def _jacobian_wrt_input_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_
272
262
273
263
return tmp_diag .reshape (b , c1 * h1 * w1 )
274
264
265
+
275
266
class Conv1d (AbstractJacobian , nn .Conv1d ):
276
267
def _jacobian_wrt_input_mult_left_vec (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
277
268
b , c1 , l1 = x .shape
@@ -668,7 +659,6 @@ def _jacobian_wrt_input_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp_
668
659
diag_Jt_tmp_J = output_tmp .reshape (b , c1 * h1 * w1 )
669
660
return diag_Jt_tmp_J
670
661
671
-
672
662
def _jacobian_wrt_weight_sandwich_full_to_full (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
673
663
return self ._jacobian_wrt_weight_mult_left (
674
664
x , val , self ._jacobian_wrt_weight_T_mult_right (x , val , tmp )
@@ -677,7 +667,7 @@ def _jacobian_wrt_weight_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp
677
667
def _jacobian_wrt_weight_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Return the diagonal of the full-to-full weight sandwich.

    The full matrix from ``self._jacobian_wrt_weight_sandwich_full_to_full``
    is computed first and its diagonal over dimensions 1 and 2 is extracted.

    Args:
        x: layer input, forwarded unchanged.
        val: layer output, forwarded unchanged.
        tmp: matrix batch, forwarded unchanged.

    Returns:
        The diagonal as a view with the batch on dimension 0.
    """
    # TODO: Implement this in a smarter way (avoid building the full matrix).
    full = self._jacobian_wrt_weight_sandwich_full_to_full(x, val, tmp)
    return torch.diagonal(full, dim1=1, dim2=2)
680
-
670
+
681
671
def _jacobian_wrt_weight_sandwich_diag_to_full(self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Diag-to-full weight sandwich; not implemented for this layer.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
683
673
@@ -695,12 +685,11 @@ def _jacobian_wrt_weight_sandwich_diag_to_diag(self, x: Tensor, val: Tensor, tmp
695
685
696
686
# define moving sum for Jt_tmp
697
687
output_tmp = torch .zeros (b , c2 * c1 * kernel_h * kernel_w , device = x .device )
688
+ flip_squared_input = torch .flip (x , [- 3 , - 2 , - 1 ]).movedim (0 , 1 ) ** 2
689
+
698
690
for i in range (b ):
699
691
# set the weight to the convolution
700
- input_single_batch = x [i : i + 1 , :, :, :]
701
- reversed_input_single_batch = torch .flip (input_single_batch , [- 3 , - 2 , - 1 ]).movedim (0 , 1 )
702
- weigth_sq = reversed_input_single_batch ** 2
703
-
692
+ weigth_sq = flip_squared_input [:, i : i + 1 , :, :]
704
693
input_tmp_single_batch = input_tmp [:, i : i + 1 , :, :]
705
694
706
695
output_tmp_single_batch = (
@@ -976,26 +965,32 @@ def _jacobian_wrt_input_sandwich(
976
965
def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Scatter an output-space matrix into input space using ``self.idx``.

    ``tmp`` is split into one ``(h2*w2, h2*w2)`` block per (batch, row-channel,
    column-channel) triple. Each entry is written into a zero
    ``(h1*w1, h1*w1)`` block at the row/column positions given by ``self.idx``.

    Args:
        x: input of shape ``(b, c1, h1, w1)``.
        val: output of shape ``(b, c2, h2, w2)``; ``c2`` must equal ``c1``.
        tmp: tensor reshapeable to ``(b, c1, h2*w2, c1, h2*w2)``.

    Returns:
        Tensor of shape ``(b, c1*h1*w1, c1*h1*w1)`` with the dtype and device
        of ``tmp``.

    Note:
        ``self.idx`` must be reshapeable to ``(b, c1, h2*w2)`` and hold flat
        positions in ``[0, h1*w1)`` — presumably the indices saved by the
        pooling forward pass; confirm against the forward method.
    """
    b, c1, h1, w1 = x.shape
    c2, h2, w2 = val.shape[1:]
    assert c1 == c2

    n_in, n_out = h1 * w1, h2 * w2
    n_blocks = b * c1 * c1

    # One (n_out, n_out) block per (batch, row-channel, column-channel).
    tmp = tmp.reshape(b, c1, n_out, c1, n_out).movedim(-2, -3).reshape(n_blocks, n_out, n_out)
    # Allocate with tmp's dtype: indexed assignment requires matching dtypes.
    Jt_tmp_J = torch.zeros((n_blocks, n_in, n_in), device=tmp.device, dtype=tmp.dtype)

    # Block index; broadcasts against the (n_blocks, n_out, n_out) position indices.
    block_idx = torch.arange(n_blocks, device=tmp.device).reshape(n_blocks, 1, 1)

    # Position indices: columns follow the column-channel, rows the row-channel.
    idx = self.idx.reshape(b, c1, n_out).unsqueeze(2).expand(-1, -1, n_out, -1)
    idx_col = idx.unsqueeze(1).expand(-1, c1, -1, -1, -1).reshape(n_blocks, n_out, n_out)
    idx_row = idx.unsqueeze(2).expand(-1, -1, c1, -1, -1).reshape(n_blocks, n_out, n_out).movedim(-1, -2)

    Jt_tmp_J[block_idx, idx_row, idx_col] = tmp

    # Interleave channel and spatial axes back into a single square matrix.
    return Jt_tmp_J.reshape(b, c1, c1, n_in, n_in).movedim(-2, -3).reshape(b, c1 * n_in, c1 * n_in)
995
990
996
991
def _jacobian_wrt_input_sandwich_full_to_diag(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Full-to-diag input sandwich; not implemented for this layer.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
998
-
993
+
999
994
def _jacobian_wrt_input_sandwich_diag_to_full(self, x: Tensor, val: Tensor, diag_tmp: Tensor) -> Tensor:
    """Diag-to-full input sandwich; not implemented for this layer.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
1001
996
@@ -1125,7 +1120,7 @@ def _jacobian_wrt_input_sandwich_full_to_full(self, x: Tensor, val: Tensor, tmp:
1125
1120
jac = torch .diag_embed (jac .view (x .shape [0 ], - 1 ))
1126
1121
tmp = torch .einsum ("bnm,bnj,bjk->bmk" , jac , tmp , jac )
1127
1122
return tmp
1128
-
1123
+
1129
1124
def _jacobian_wrt_input_sandwich_full_to_diag (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
1130
1125
jac = self ._jacobian (x , val )
1131
1126
jac = torch .diag_embed (jac .view (x .shape [0 ], - 1 ))
0 commit comments