Skip to content

Commit df7ae11

Browse files
committed
Add device arg for patch embed resize, fix #2024
1 parent cd8d9d9 commit df7ae11

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/layers/patch_embed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def get_resize_mat(_old_size, _new_size):
196196
return np.stack(mat).T
197197

198198
resize_mat = get_resize_mat(old_size, new_size)
199-
resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
199+
resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
200200

201201
def resample_kernel(kernel):
202202
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)

0 commit comments

Comments
 (0)