Skip to content

Commit 1ee5018

Browse files
author
Frederik Rahbaek Warburg
committed
fixed small device error
1 parent 560f25a commit 1ee5018

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

stochman/nnj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _jacobian_wrt_input(self, x: Tensor, val: Tensor) -> Tensor:
8383
def _jacobian_wrt_weight(self, x: Tensor, val: Tensor) -> Tensor:
8484
b, c1 = x.shape
8585
c2 = val.shape[1]
86-
out_identity = torch.diag_embed(torch.ones(c2))
86+
out_identity = torch.diag_embed(torch.ones(c2, device=x.device))
8787
jacobian = torch.einsum('bk,ij->bijk', x, out_identity).reshape(b,c2,c2*c1)
8888
if self.bias is not None:
8989
jacobian = torch.cat([jacobian, out_identity.unsqueeze(0).expand(b,-1,-1)], dim=2)

0 commit comments

Comments
 (0)