Commit 560f25a

Linear Layer completed
1 parent 851acdc

File tree: 1 file changed (+29, -16 lines)


stochman/nnj.py

Lines changed: 29 additions & 16 deletions
@@ -76,6 +76,18 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tens
 
     def _jacobian_wrt_input_transpose_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         return F.linear(jac_in.movedim(1, -1), self.weight.T, bias=None).movedim(-1, 1)
+
+    def _jacobian_wrt_input(self, x: Tensor, val: Tensor) -> Tensor:
+        return self.weight
+
+    def _jacobian_wrt_weight(self, x: Tensor, val: Tensor) -> Tensor:
+        b, c1 = x.shape
+        c2 = val.shape[1]
+        out_identity = torch.diag_embed(torch.ones(c2))
+        jacobian = torch.einsum('bk,ij->bijk', x, out_identity).reshape(b, c2, c2 * c1)
+        if self.bias is not None:
+            jacobian = torch.cat([jacobian, out_identity.unsqueeze(0).expand(b, -1, -1)], dim=2)
+        return jacobian
 
     def _jacobian_wrt_input_sandwich(
         self, x: Tensor, val: Tensor, tmp: Tensor, diag_inp: bool = False, diag_out: bool = False
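
Note on the additions above: _jacobian_wrt_input simply returns self.weight, since for y = Wx + b the Jacobian with respect to x is W for every sample, while _jacobian_wrt_weight builds the per-sample Jacobian with respect to the flattened weight, appending an identity block for the bias when one is present, so the result has shape (batch, out_features, out_features*in_features [+ out_features]). The snippet below is a minimal sanity check against autograd; it is a sketch, not part of the commit, and it assumes the edited class is stochman's nnj.Linear (a torch.nn.Linear subclass). The helper f is purely illustrative.

    import torch
    import torch.nn.functional as F
    from torch.autograd.functional import jacobian
    from stochman import nnj  # assumption: the methods above live on nnj.Linear

    layer = nnj.Linear(3, 2, bias=True)          # in_features=3, out_features=2
    x = torch.randn(5, 3)
    val = F.linear(x, layer.weight, layer.bias)  # layer output, shape (5, 2)

    # Analytic Jacobian from the commit: shape (5, 2, 2*3 + 2) with bias.
    J = layer._jacobian_wrt_weight(x, val)

    # Autograd reference for the first sample, w.r.t. [vec(W), b] (illustrative helper).
    def f(theta):
        W, b = theta[:6].reshape(2, 3), theta[6:]
        return x[0] @ W.T + b

    theta = torch.cat([layer.weight.reshape(-1), layer.bias])
    J_ref = jacobian(f, theta)                   # shape (2, 8)
    print(torch.allclose(J[0], J_ref, atol=1e-6))  # expected: True
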
@@ -119,30 +131,31 @@ def _jacobian_wrt_input_sandwich_diag_to_diag(
 
     def _jacobian_wrt_weight_sandwich_full_to_full(
         self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
-        raise NotImplementedError
+        jacobian = self._jacobian_wrt_weight(x,val)
+        return torch.einsum('bji,bjk,bkq->biq', jacobian, tmp, jacobian)
 
     def _jacobian_wrt_weight_sandwich_full_to_diag(
         self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
-
-        b, c = x.shape
-        diag_elements = torch.diagonal(tmp, dim1=1, dim2=2)
-        feat_k2 = (x**2).unsqueeze(1)
-
-        h_k = torch.bmm(diag_elements.unsqueeze(2), feat_k2).view(b, -1)
-
-        # has a bias
-        if self.bias is not None:
-            h_k = torch.cat([h_k, diag_elements], dim=1)
-
-        return h_k
-
+        tmp_diag = torch.diagonal(tmp, dim1=1, dim2=2)
+        return self._jacobian_wrt_weight_sandwich_diag_to_diag(x, val, tmp_diag)
+
     def _jacobian_wrt_weight_sandwich_diag_to_full(
         self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
-        raise NotImplementedError
+        jacobian = self._jacobian_wrt_weight(x,val)
+        return torch.einsum('bji,bj,bjq->biq', jacobian, tmp_diag, jacobian)
 
     def _jacobian_wrt_weight_sandwich_diag_to_diag(
         self, x: Tensor, val: Tensor, tmp_diag: Tensor) -> Tensor:
-        raise NotImplementedError
+
+        b, c1 = x.shape
+        c2 = val.shape[1]
+
+        Jt_tmp_J = torch.bmm(tmp_diag.unsqueeze(2), (x**2).unsqueeze(1)).view(b, c1*c2)
+
+        if self.bias is not None:
+            Jt_tmp_J = torch.cat([Jt_tmp_J, tmp_diag], dim=1)
+
+        return Jt_tmp_J
 
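
Note on this hunk: the sandwich variants all evaluate J^T tmp J for the per-sample weight Jacobian J, where full/diag indicates whether the incoming tmp and the returned result are dense (batch, c2, c2) matrices or just their per-sample diagonals. full_to_diag now extracts the diagonal of tmp and defers to diag_to_diag, which, as in the code it replaces, uses only the diagonal entries of tmp. Below is a small consistency check between the dense and diagonal paths; again a sketch, not part of the commit, assuming the class is stochman's nnj.Linear.

    import torch
    import torch.nn.functional as F
    from stochman import nnj  # assumption: the methods above live on nnj.Linear

    layer = nnj.Linear(4, 3, bias=True)
    x = torch.randn(2, 4)
    val = F.linear(x, layer.weight, layer.bias)

    tmp_diag = torch.rand(2, 3)            # per-sample diagonal of tmp
    tmp = torch.diag_embed(tmp_diag)       # the same matrix, dense (2, 3, 3)

    full = layer._jacobian_wrt_weight_sandwich_full_to_full(x, val, tmp)        # (2, 15, 15)
    diag = layer._jacobian_wrt_weight_sandwich_diag_to_diag(x, val, tmp_diag)   # (2, 15)

    print(torch.allclose(torch.diagonal(full, dim1=1, dim2=2), diag, atol=1e-6))  # expected: True

With a genuinely diagonal tmp the two paths agree; passing only the diagonal of a non-diagonal tmp drops its off-diagonal terms by construction.
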

0 commit comments
