@@ -134,14 +134,19 @@ def _jacobian_wrt_weight_sandwich(
134
134
return None
135
135
136
136
def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
137
+ if diag :
138
+ return self ._jacobian_wrt_input_diag_sandwich (x , val , tmp )
139
+ else :
140
+ return self ._jacobian_wrt_input_full_sandwich (x , val , tmp )
137
141
142
+ def _jacobian_wrt_input_diag_sandwich (self , x : Tensor , val : Tensor , diag_tmp : Tensor ) -> Tensor :
138
143
b , c1 , h1 , w1 = x .shape
139
144
c2 , h2 , w2 = val .shape [1 :]
140
145
141
146
weight = torch .ones (c2 , c1 , int (self .scale_factor ), int (self .scale_factor ), device = x .device )
142
147
143
- tmp = F .conv2d (
144
- tmp .reshape (- 1 , c2 , h2 , w2 ),
148
+ diag_tmp = F .conv2d (
149
+ diag_tmp .reshape (- 1 , c2 , h2 , w2 ),
145
150
weight = weight ,
146
151
bias = None ,
147
152
stride = int (self .scale_factor ),
@@ -150,8 +155,47 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag
150
155
groups = 1 ,
151
156
)
152
157
153
- return tmp .reshape (b , c1 * h1 * w1 )
158
+ return diag_tmp .reshape (b , c1 * h1 * w1 )
159
+
160
+ def _jacobian_wrt_input_full_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
161
+ b , c1 , h1 , w1 = x .shape
162
+ c2 , h2 , w2 = val .shape [1 :]
163
+
164
+ assert c1 == c2
165
+
166
+ weight = torch .ones (1 , 1 , int (self .scale_factor ), int (self .scale_factor ), device = x .device )
167
+
168
+ tmp = tmp .reshape (b , c2 , h2 * w2 , c2 , h2 * w2 )
169
+ tmp = tmp .movedim (2 ,3 )
170
+ tmp_J = F .conv2d (
171
+ tmp .reshape (b * c2 * c2 * h2 * w2 , 1 , h2 , w2 ),
172
+ weight = weight ,
173
+ bias = None ,
174
+ stride = int (self .scale_factor ),
175
+ padding = 0 ,
176
+ dilation = 1 ,
177
+ groups = 1 ,
178
+ ).reshape (b * c2 * c2 , h2 * w2 , h1 * w1 )
179
+
180
+ Jt_tmpt = tmp_J .movedim (- 1 ,- 2 )
181
+
182
+ Jt_tmpt_J = F .conv2d (
183
+ Jt_tmpt .reshape (b * c2 * c2 * h1 * w1 , 1 , h2 , w2 ),
184
+ weight = weight ,
185
+ bias = None ,
186
+ stride = int (self .scale_factor ),
187
+ padding = 0 ,
188
+ dilation = 1 ,
189
+ groups = 1 ,
190
+ ).reshape (b * c2 * c2 , h1 * w1 , h1 * w1 )
191
+
192
+ Jt_tmp_J = Jt_tmpt_J .movedim (- 1 ,- 2 )
193
+
194
+ Jt_tmp_J = Jt_tmp_J .reshape (b , c2 , c2 , h1 * w1 , h1 * w1 )
195
+ Jt_tmp_J = Jt_tmp_J .movedim (2 ,3 )
196
+ Jt_tmp_J = Jt_tmp_J .reshape (b , c2 * h1 * w1 , c2 * h1 * w1 )
154
197
198
+ return Jt_tmp_J
155
199
156
200
class Conv1d (AbstractJacobian , nn .Conv1d ):
157
201
def _jacobian_wrt_input_mult_left_vec (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
@@ -490,7 +534,7 @@ def _jacobian_wrt_weight_mult_left(
490
534
# transpose
491
535
tmp_J = Jt_tmptt_cols .movedim (0 , 1 )
492
536
493
- return tmp
537
+ return tmp_J
494
538
495
539
def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
496
540
if diag :
@@ -801,9 +845,12 @@ def _jacobian_wrt_weight_sandwich(
801
845
return None
802
846
803
847
def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
804
- return self ._jacobian_wrt_input_diag_sandwich (x , val , tmp )
848
+ if diag :
849
+ return self ._jacobian_wrt_input_diag_sandwich (x , val , tmp )
850
+ else :
851
+ return self ._jacobian_wrt_input_full_sandwich (x , val , tmp )
805
852
806
- def _jacobian_wrt_input_diag_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
853
+ def _jacobian_wrt_input_diag_sandwich (self , x : Tensor , val : Tensor , diag_tmp : Tensor ) -> Tensor :
807
854
b , c1 , h1 , w1 = x .shape
808
855
c2 , h2 , w2 = val .shape [1 :]
809
856
@@ -813,10 +860,14 @@ def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor)
813
860
arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), h2 * w2 ).long ()
814
861
arange_repeated = arange_repeated .reshape (b * c2 , h2 * w2 )
815
862
816
- new_tmp [arange_repeated , idx ] = tmp .reshape (b * c2 , h2 * w2 )
863
+ new_tmp [arange_repeated , idx ] = diag_tmp .reshape (b * c2 , h2 * w2 )
817
864
818
865
return new_tmp .reshape (b , c1 * h1 * w1 )
819
866
867
+ def _jacobian_wrt_input_full_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
868
+
869
+ return tmp
870
+
820
871
821
872
class MaxPool3d (AbstractJacobian , nn .MaxPool3d ):
822
873
def forward (self , input : Tensor ):
0 commit comments