@@ -76,10 +76,8 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
76
76
def _jacobian_wrt_input_transpose_mult_left_vec (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
77
77
return F .linear (jac_in .movedim (1 , - 1 ), self .weight .T , bias = None ).movedim (- 1 , 1 )
78
78
79
- def _sandwich_full_wrt_input (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
80
- return torch .einsum ("nm,bnj,jk->bmk" , self .weight , tmp , self .weight )
81
-
82
- def _sandwich_full_wrt_weight (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
79
+ def _jacobian_wrt_weight_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
80
+
83
81
b , c = x .shape
84
82
diag_elements = torch .diagonal (tmp , dim1 = 1 , dim2 = 2 )
85
83
feat_k2 = (x ** 2 ).unsqueeze (1 )
@@ -92,6 +90,9 @@ def _sandwich_full_wrt_weight(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tens
92
90
93
91
return h_k
94
92
93
+ def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
94
+ return torch .einsum ("nm,bnj,jk->bmk" , self .weight , tmp , self .weight )
95
+
95
96
96
97
class PosLinear (AbstractJacobian , nn .Linear ):
97
98
def forward (self , x : Tensor ):
@@ -123,6 +124,14 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
123
124
.movedim (dims2 , dims1 )
124
125
)
125
126
127
+ def _jacobian_wrt_weight_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
128
+ # non parametric, so return empty
129
+ return []
130
+
131
+ def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
132
+ raise NotImplementedError
133
+
134
+
126
135
127
136
class Conv1d (AbstractJacobian , nn .Conv1d ):
128
137
def _jacobian_wrt_input_mult_left_vec (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
@@ -358,7 +367,7 @@ def _jacobian_wrt_weight_T_mult_right(
358
367
dilation = self .dilation ,
359
368
groups = self .groups ,
360
369
)
361
- .reshape (b , * tmp_single_batch .shape [4 :], c1 , kernel_h , kernel_w )
370
+ .reshape (c2 , * tmp_single_batch .shape [4 :], c1 , kernel_h , kernel_w )
362
371
.movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
363
372
)
364
373
@@ -381,7 +390,7 @@ def _jacobian_wrt_weight_T_mult_right(
381
390
dilation = self .dilation ,
382
391
groups = self .groups ,
383
392
)
384
- .reshape (b , * tmp .shape [4 :], c1 , kernel_h , kernel_w )
393
+ .reshape (c2 , * tmp .shape [4 :], c1 , kernel_h , kernel_w )
385
394
.movedim ((- 3 , - 2 , - 1 ), (1 , 2 , 3 ))
386
395
)
387
396
@@ -390,6 +399,18 @@ def _jacobian_wrt_weight_T_mult_right(
390
399
391
400
return Jt_tmp
392
401
402
+ def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
403
+ if diag :
404
+ return self ._jacobian_wrt_input_diag_sandwich (x , val , tmp )
405
+ else :
406
+ return self ._jacobian_wrt_input_full_sandwich (x , val , tmp )
407
+
408
+ def _jacobian_wrt_weight_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
409
+ if diag :
410
+ return self ._jacobian_wrt_weight_diag_sandwich (x , val , tmp )
411
+ else :
412
+ return self ._jacobian_wrt_weight_full_sandwich (x , val , tmp )
413
+
393
414
def _jacobian_wrt_input_full_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
394
415
return self ._jacobian_wrt_input_mult_left (x , val , self ._jacobian_wrt_input_T_mult_right (x , val , tmp ))
395
416
@@ -662,14 +683,21 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
662
683
jac_in = jac_in [arange_repeated , idx , :, :, :].reshape (* val .shape , * jac_in_orig_shape [4 :])
663
684
return jac_in
664
685
665
- def _sandwich_full_wrt_input (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
686
+ def _jacobian_wrt_weight_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
687
+ # non parametric, so return empty
688
+ return []
689
+
690
+ def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
691
+ return self ._jacobian_wrt_input_full_sandwich (x , val , tmp )
692
+
693
+ def _jacobian_wrt_input_full_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
666
694
667
695
new_tmp = torch .zeros_like (x )
668
696
new_tmp [self .idx ] = tmp
669
697
670
698
return new_tmp
671
699
672
- def _sandwich_diag_wrt_input (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
700
+ def _jacobian_wrt_input_diag_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
673
701
pass
674
702
675
703
@@ -758,15 +786,25 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
758
786
jac = 1.0 - val ** 2
759
787
return jac
760
788
761
- def _sandwich_full_wrt_input (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
789
+ def _jacobian_wrt_weight_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
790
+ # non parametric, so return empty
791
+ return []
792
+
793
+ def _jacobian_wrt_input_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor , diag : bool = False ) -> Tensor :
794
+ if diag :
795
+ return self ._jacobian_wrt_input_diag_sandwich (x , val , tmp )
796
+ else :
797
+ return self ._jacobian_wrt_input_full_sandwich (x , val , tmp )
798
+
799
+ def _jacobian_wrt_input_full_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
762
800
763
801
jac = self ._jacobian (x , val )
764
802
jac = torch .diag_embed (jac .view (x .shape [0 ], - 1 ))
765
803
tmp = torch .einsum ("bnm,bnj,bjk->bmk" , jac , tmp , jac )
766
804
767
805
return tmp
768
806
769
- def _sandwich_diag_wrt_input (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
807
+ def _jacobian_wrt_input_diag_sandwich (self , x : Tensor , val : Tensor , tmp : Tensor ) -> Tensor :
770
808
771
809
jac = self ._jacobian (x , val )
772
810
jac = jac .view (x .shape [0 ], - 1 )
0 commit comments