@@ -76,6 +76,18 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
76
76
77
77
def _jacobian_wrt_input_transpose_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
    """Apply the transposed weight matrix along dimension 1 of ``jac_in``.

    Dimension 1 of ``jac_in`` is moved to the end, a bias-free linear map
    with ``self.weight.T`` is applied to that trailing dimension, and the
    resulting dimension is moved back to position 1.

    ``x`` and ``val`` are not read here.
    NOTE(review): presumably ``jac_in`` is (batch, out_features, ...) and
    the result is (batch, in_features, ...) -- confirm against callers.
    """
    # F.linear contracts over the last dimension, so bring dim 1 there first.
    trailing = jac_in.movedim(1, -1)
    mapped = F.linear(trailing, self.weight.T, bias=None)
    # Restore the original layout with the mapped dimension at position 1.
    return mapped.movedim(-1, 1)
79
+
80
def _jacobian_wrt_input(self, x: Tensor, val: Tensor) -> Tensor:
    """Return the Jacobian of the layer output with respect to its input.

    This is ``self.weight`` itself, returned as-is (no copy and no batch
    dimension added). ``x`` and ``val`` are accepted but not read.
    """
    weight = self.weight
    return weight
82
+
83
def _jacobian_wrt_weight(self, x: Tensor, val: Tensor) -> Tensor:
    """Build the dense Jacobian of the output with respect to the parameters.

    Args:
        x: 2-D input of shape ``(b, c1)``.
        val: layer output; only ``val.shape[1]`` (``c2``) is read.

    Returns:
        Tensor of shape ``(b, c2, c2 * c1)``, or ``(b, c2, c2 * c1 + c2)``
        when ``self.bias`` is not ``None``. Entry ``[n, i, j * c1 + k]``
        equals ``x[n, k]`` if ``i == j`` and 0 otherwise; the trailing
        ``c2`` columns (bias part) form an identity matrix.

    The identity is created with the dtype and device of ``x``. Building it
    with the global defaults breaks (or silently mixes dtypes) as soon as
    the inputs live on another device or use a non-default precision.
    """
    b, c1 = x.shape
    c2 = val.shape[1]
    out_identity = torch.eye(c2, dtype=x.dtype, device=x.device)
    # Outer product places a copy of x on the block diagonal, one per output.
    jacobian = torch.einsum('bk,ij->bijk', x, out_identity).reshape(b, c2, c2 * c1)
    if self.bias is not None:
        # d(out_i)/d(bias_j) = delta_ij, identical for every batch element.
        jacobian = torch.cat([jacobian, out_identity.unsqueeze(0).expand(b, -1, -1)], dim=2)
    return jacobian
79
91
80
92
def _jacobian_wrt_input_sandwich (
81
93
self , x : Tensor , val : Tensor , tmp : Tensor , diag_inp : bool = False , diag_out : bool = False
@@ -119,30 +131,31 @@ def _jacobian_wrt_input_sandwich_diag_to_diag(
119
131
120
132
def _jacobian_wrt_weight_sandwich_full_to_full(
        self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Return ``J^T @ tmp @ J`` per batch element, with a full ``tmp``.

    ``J`` is whatever ``self._jacobian_wrt_weight(x, val)`` returns. The
    result satisfies
    ``out[b, p, q] = sum_{o, m} J[b, o, p] * tmp[b, o, m] * J[b, m, q]``.
    """
    jac = self._jacobian_wrt_weight(x, val)
    # o / m index the output dimension, p / q the parameter dimension.
    return torch.einsum('bop,bom,bmq->bpq', jac, tmp, jac)
123
136
124
137
def _jacobian_wrt_weight_sandwich_full_to_diag(
        self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
    """Sandwich a full ``tmp`` and return the diagonal-form result.

    Only the diagonal of ``tmp`` over dimensions 1 and 2 is used; it is
    handed to ``_jacobian_wrt_weight_sandwich_diag_to_diag`` unchanged.
    """
    diagonal_of_tmp = torch.diagonal(tmp, dim1=1, dim2=2)
    return self._jacobian_wrt_weight_sandwich_diag_to_diag(x, val, diagonal_of_tmp)
141
+
139
142
def _jacobian_wrt_weight_sandwich_diag_to_full(
        self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Return ``J^T @ diag(tmp_diag) @ J`` per batch element as a full matrix.

    ``J`` is whatever ``self._jacobian_wrt_weight(x, val)`` returns. The
    result satisfies
    ``out[b, p, q] = sum_o J[b, o, p] * tmp_diag[b, o] * J[b, o, q]``.
    """
    jac = self._jacobian_wrt_weight(x, val)
    # o indexes the output dimension, p / q the parameter dimension.
    return torch.einsum('bop,bo,boq->bpq', jac, tmp_diag, jac)
142
146
143
147
def _jacobian_wrt_weight_sandwich_diag_to_diag(
        self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
    """Sandwich a diagonal ``tmp_diag`` and return the diagonal-form result.

    Args:
        x: 2-D input of shape ``(b, c1)``.
        val: layer output; only ``val.shape[1]`` (``c2``) is read.
        tmp_diag: tensor that is batch-multiplied as a ``(b, c2, 1)`` column.

    Returns:
        Tensor of shape ``(b, c1 * c2)`` whose entry ``[n, j * c1 + k]`` is
        ``tmp_diag[n, j] * x[n, k] ** 2``. When ``self.bias`` is not
        ``None``, ``tmp_diag`` is appended along dim 1.
    """
    batch, n_in = x.shape
    n_out = val.shape[1]

    # Batched outer product: (b, c2, 1) @ (b, 1, c1) -> (b, c2, c1).
    column = tmp_diag.unsqueeze(2)
    squared_row = (x ** 2).unsqueeze(1)
    flat = torch.bmm(column, squared_row).view(batch, n_in * n_out)

    if self.bias is None:
        return flat
    # Bias block of the Jacobian is an identity, so its diagonal is tmp_diag.
    return torch.cat([flat, tmp_diag], dim=1)
146
159
147
160
148
161
0 commit comments