@@ -142,7 +142,7 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag
142
142
def _jacobian_wrt_input_diag_sandwich (self , x : Tensor , val : Tensor , diag_tmp : Tensor ) -> Tensor :
143
143
b , c1 , h1 , w1 = x .shape
144
144
c2 , h2 , w2 = val .shape [1 :]
145
-
145
+
146
146
weight = torch .ones (c2 , c1 , int (self .scale_factor ), int (self .scale_factor ), device = x .device )
147
147
148
148
diag_tmp = F .conv2d (
@@ -856,17 +856,36 @@ def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, diag_tmp: Te
856
856
857
857
new_tmp = torch .zeros_like (x )
858
858
new_tmp = new_tmp .reshape (b * c1 , h1 * w1 )
859
- idx = self .idx .reshape (b * c2 , h2 * w2 )
859
+
860
+ # indexes for batch and channel
860
861
arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), h2 * w2 ).long ()
861
862
arange_repeated = arange_repeated .reshape (b * c2 , h2 * w2 )
863
+ # indexes for height and width
864
+ idx = self .idx .reshape (b * c2 , h2 * w2 )
862
865
863
866
new_tmp [arange_repeated , idx ] = diag_tmp .reshape (b * c2 , h2 * w2 )
864
867
865
868
return new_tmp .reshape (b , c1 * h1 * w1 )
866
869
867
870
def _jacobian_wrt_input_full_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
871
+ b , c1 , h1 , w1 = x .shape
872
+ c2 , h2 , w2 = val .shape [1 :]
873
+ assert c1 == c2
868
874
869
- return tmp
875
+ tmp = tmp .reshape (b , c1 , h2 * w2 , c1 , h2 * w2 ).movedim (- 2 ,- 3 ).reshape (b * c1 * c1 , h2 * w2 , h2 * w2 )
876
+ Jt_tmp_J = torch .zeros ((b * c1 * c1 , h1 * w1 , h1 * w1 ))
877
+ # indexes for batch and channel
878
+ arange_repeated = torch .repeat_interleave (torch .arange (b * c1 * c1 ), h2 * w2 * h2 * w2 ).long ()
879
+ arange_repeated = arange_repeated .reshape (b * c1 * c1 , h2 * w2 , h2 * w2 )
880
+ # indexes for height and width
881
+ idx = self .idx .reshape (b , c1 , h2 * w2 ).unsqueeze (2 ).expand (- 1 , - 1 , h2 * w2 , - 1 )
882
+ idx_col = idx .unsqueeze (1 ).expand (- 1 , c1 , - 1 , - 1 , - 1 ).reshape (b * c1 * c1 , h2 * w2 , h2 * w2 )
883
+ idx_row = idx .unsqueeze (2 ).expand (- 1 , - 1 , c1 , - 1 , - 1 ).reshape (b * c1 * c1 , h2 * w2 , h2 * w2 ).movedim (- 1 ,- 2 )
884
+
885
+ Jt_tmp_J [arange_repeated , idx_row , idx_col ] = tmp
886
+ Jt_tmp_J = Jt_tmp_J .reshape (b , c1 , c1 , h1 * w1 , h1 * w1 ).movedim (- 2 ,- 3 ).reshape (b , c1 * h1 * w1 , c1 * h1 * w1 )
887
+
888
+ return Jt_tmp_J
870
889
871
890
872
891
class MaxPool3d (AbstractJacobian , nn .MaxPool3d ):
0 commit comments