Skip to content

Commit 39bc256

Browse files
author
Frederik Rahbaek Warburg
committed
device error
1 parent 983177d commit 39bc256

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

stochman/nnj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ def _jacobian_wrt_input_full_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor)
873873
assert c1==c2
874874

875875
tmp = tmp.reshape(b, c1, h2*w2, c1, h2*w2).movedim(-2,-3).reshape(b*c1*c1, h2*w2, h2*w2)
876-
Jt_tmp_J = torch.zeros((b*c1*c1, h1*w1, h1*w1))
876+
Jt_tmp_J = torch.zeros((b*c1*c1, h1*w1, h1*w1), device=tmp.device)
877877
# indexes for batch and channel
878878
arange_repeated = torch.repeat_interleave(torch.arange(b*c1*c1), h2*w2 * h2*w2).long()
879879
arange_repeated = arange_repeated.reshape(b*c1*c1, h2*w2, h2*w2)

0 commit comments

Comments
 (0)