Commit 7fd7206

add explicit CPU for persistent buffer
1 parent 809a2b5 commit 7fd7206

11 files changed: +35 -28 lines changed
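
Reviewer sketch (not part of the commit): without an explicit device argument, PyTorch factory calls such as torch.arange and torch.zeros follow the active global default device, so buffers built in __init__ can land on meta or cuda when a caller has used torch.set_default_device; passing device="cpu" pins these small index/bias tensors to CPU regardless of that context. A minimal illustration, assuming PyTorch >= 2.0 for torch.set_default_device:

# Sketch: tensor factories without a device follow the global default device.
import torch

torch.set_default_device("meta")            # e.g. deferred/meta-device model init

implicit = torch.arange(7)                  # follows the default -> meta tensor
explicit = torch.arange(7, device="cpu")    # pinned to CPU regardless of default

print(implicit.device)                      # meta
print(explicit.device)                      # cpu

torch.set_default_device("cpu")             # restore the usual default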

timm/layers/lambda_layer.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 
 def rel_pos_indices(size):
     size = to_2tuple(size)
-    pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
+    pos = torch.stack(ndgrid(torch.arange(size[0], device="cpu"), torch.arange(size[1], device="cpu"))).flatten(1)
     rel_pos = pos[:, None, :] - pos[:, :, None]
     rel_pos[0] += size[0] - 1
     rel_pos[1] += size[1] - 1

timm/layers/pos_embed_rel.py

Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,7 @@ def gen_relative_position_index(
     # get pair-wise relative position index for each token inside the window
     assert k_size is None, 'Different q & k sizes not currently supported'  # FIXME
 
-    coords = torch.stack(ndgrid(torch.arange(q_size[0]), torch.arange(q_size[1]))).flatten(1)  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(q_size[0], device="cpu"), torch.arange(q_size[1], device="cpu"))).flatten(1)  # 2, Wh, Ww
     relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
     relative_coords[:, :, 0] += q_size[0] - 1  # shift to start from 0
@@ -307,8 +307,8 @@ def gen_relative_log_coords(
 ):
     assert mode in ('swin', 'cr')
     # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
-    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
-    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
+    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], device="cpu").to(torch.float32)
+    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], device="cpu").to(torch.float32)
     relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
     relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
     if mode == 'swin':
@@ -415,7 +415,7 @@ def generate_lookup_tensor(
         max_relative_position = length - 1
     # Return the cached lookup tensor, otherwise compute it and cache it.
     vocab_size = 2 * max_relative_position + 1
-    ret = torch.zeros(length, length, vocab_size)
+    ret = torch.zeros(length, length, vocab_size, device="cpu")
     for i in range(length):
         for x in range(length):
             v = x - i + max_relative_position

timm/layers/pos_embed_sincos.py

Lines changed: 7 additions & 1 deletion
@@ -163,7 +163,7 @@ def __init__(
         self.keep_spatial = keep_spatial
         self.register_buffer(
             'bands',
-            pixel_freq_bands(max_res, num_bands),
+            pixel_freq_bands(max_res, num_bands, device="cpu"),
             persistent=False,
         )
 
@@ -305,12 +305,14 @@ def __init__(
                     dim // 4,
                     float(max_res),
                     linear_bands=linear_bands,
+                    device="cpu",
                 )
             else:
                 bands = freq_bands(
                     dim // 4,
                     temperature=temperature,
                     step=1,
+                    device="cpu",
                 )
             self.register_buffer(
                 'bands',
@@ -328,6 +330,7 @@ def __init__(
                 linear_bands=linear_bands,
                 in_pixels=in_pixels,
                 ref_feat_shape=self.ref_feat_shape,
+                device="cpu",
             )
             self.bands = None
             self.register_buffer(
@@ -392,12 +395,14 @@ def __init__(
                     dim // 4,
                     float(max_res),
                     linear_bands=linear_bands,
+                    device="cpu",
                 )
             else:
                 bands = freq_bands(
                     dim // 4,
                     temperature=temperature,
                     step=1,
+                    device="cpu",
                 )
             self.register_buffer(
                 'bands',
@@ -414,6 +419,7 @@ def __init__(
                 linear_bands=linear_bands,
                 in_pixels=in_pixels,
                 ref_feat_shape=self.ref_feat_shape,
+                device="cpu",
             )
             self.bands = None
             self.register_buffer(
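
The 'bands' edits above all follow one pattern: build the frequency tensor on CPU up front and register it with persistent=False, so it stays out of the state_dict but is still moved by Module.to(). A hedged sketch of that pattern; inv_freq_bands below is an illustrative stand-in, not timm's pixel_freq_bands/freq_bands:

# Hedged sketch of the buffer pattern in this commit; `inv_freq_bands` is an
# illustrative stand-in, not timm's freq_bands/pixel_freq_bands.
import torch
import torch.nn as nn


def inv_freq_bands(num_bands: int, temperature: float = 10000.0) -> torch.Tensor:
    exp = torch.arange(num_bands, device="cpu").to(torch.float32) / num_bands
    return 1.0 / (temperature ** exp)


class RotaryBands(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        # Built on CPU explicitly; not saved to the state_dict, but still
        # moved by .to()/.cuda() along with the rest of the module.
        self.register_buffer('bands', inv_freq_bands(dim // 4), persistent=False)


m = RotaryBands()
print(m.bands.device)        # cpu
# m.cuda() would move `bands` to the GPU together with any parameters.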

timm/models/beit.py

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,7 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
     # cls to token & token 2 cls & cls to cls
     # get pair-wise relative position index for each token inside the window
     window_area = window_size[0] * window_size[1]
-    coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1])))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(window_size[0], device="cpu"), torch.arange(window_size[1], device="cpu")))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
@@ -105,7 +105,7 @@ def __init__(
         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
-            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
+            self.register_buffer('k_bias', torch.zeros(all_head_dim, device="cpu"), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
         else:
             self.q_bias = None
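
The k_bias buffer touched here supports BEiT-style attention where the fused qkv projection has learnable biases for q and v only; the all-zero k bias is concatenated in at forward time. A hedged sketch of that idea (shapes and names are illustrative, not the exact timm module):

# Hedged sketch: learnable q/v biases plus a fixed all-zero k bias, fused into
# one qkv projection. Illustrative only; not the exact timm/BEiT Attention.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedQKV(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(dim))
        # Zero k bias kept as a buffer: never trained, never saved, built on CPU.
        self.register_buffer('k_bias', torch.zeros(dim, device="cpu"), persistent=False)
        self.v_bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
        return F.linear(x, self.qkv.weight, qkv_bias)


x = torch.randn(2, 8, 32)
print(FusedQKV()(x).shape)   # torch.Size([2, 8, 96])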

timm/models/efficientformer_v2.py

Lines changed: 4 additions & 4 deletions
@@ -131,7 +131,7 @@ def __init__(
         self.act = act_layer()
         self.proj = ConvNorm(self.dh, dim, 1)
 
-        pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(self.resolution[0], device="cpu"), torch.arange(self.resolution[1], device="cpu"))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N))
@@ -233,10 +233,10 @@ def __init__(
         self.proj = ConvNorm(self.dh, self.out_dim, 1)
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N))
-        k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
+        k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0], device="cpu"), torch.arange(self.resolution[1], device="cpu"))).flatten(1)
         q_pos = torch.stack(ndgrid(
-            torch.arange(0, self.resolution[0], step=2),
-            torch.arange(0, self.resolution[1], step=2)
+            torch.arange(0, self.resolution[0], step=2, device="cpu"),
+            torch.arange(0, self.resolution[1], step=2, device="cpu")
         )).flatten(1)
         rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]

timm/models/eva.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def __init__(
             self.q_proj = self.k_proj = self.v_proj = None
             if qkv_bias:
                 self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
-                self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
+                self.register_buffer('k_bias', torch.zeros(all_head_dim, device="cpu"), persistent=False)
                 self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
             else:
                 self.q_bias = self.k_bias = self.v_bias = None

timm/models/levit.py

Lines changed: 4 additions & 4 deletions
@@ -195,7 +195,7 @@ def __init__(
         ]))
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(resolution[0], device="cpu"), torch.arange(resolution[1], device="cpu"))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
@@ -291,10 +291,10 @@ def __init__(
         ]))
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        k_pos = torch.stack(ndgrid(torch.arange(resolution[0], device="cpu"), torch.arange(resolution[1], device="cpu"))).flatten(1)
         q_pos = torch.stack(ndgrid(
-            torch.arange(0, resolution[0], step=stride),
-            torch.arange(0, resolution[1], step=stride)
+            torch.arange(0, resolution[0], step=stride, device="cpu"),
+            torch.arange(0, resolution[1], step=stride, device="cpu"),
         )).flatten(1)
         rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]

timm/models/swin_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -79,7 +79,7 @@ def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
 
 def get_relative_position_index(win_h: int, win_w: int):
     # get pair-wise relative position index for each token inside the window
-    coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w)))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(win_h, device="cpu"), torch.arange(win_w, device="cpu")))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
@@ -291,7 +291,7 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens
                 dtype = x.dtype
             else:
                 H, W = self.input_resolution
-                device = None
+                device = "cpu"
                 dtype = None
             H = math.ceil(H / self.window_size[0]) * self.window_size[0]
             W = math.ceil(W / self.window_size[1]) * self.window_size[1]
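
The device = None -> device = "cpu" change only affects the cached-mask path taken when no input tensor is given; with an input, the mask is still built with x.device and x.dtype. A hedged sketch of that fallback as a standalone function (not the actual SwinTransformerBlock method):

# Hedged sketch of the fallback after this change: use the input's device/dtype
# when available, otherwise build the cached mask explicitly on CPU.
from typing import Optional
import torch


def make_mask(x: Optional[torch.Tensor], H: int, W: int) -> torch.Tensor:
    if x is not None:
        device, dtype = x.device, x.dtype
    else:
        device, dtype = torch.device("cpu"), None   # explicit CPU, not the ambient default
    return torch.zeros((1, H, W, 1), device=device, dtype=dtype)


print(make_mask(None, 7, 7).device)                       # cpu
print(make_mask(torch.ones(1, 7, 7, 96), 7, 7).dtype)     # torch.float32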

timm/models/swin_transformer_v2.py

Lines changed: 6 additions & 6 deletions
@@ -110,7 +110,7 @@ def __init__(
         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(dim))
-            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
+            self.register_buffer('k_bias', torch.zeros(dim, device="cpu"), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(dim))
         else:
             self.q_bias = None
@@ -125,8 +125,8 @@ def __init__(
 
     def _make_pair_wise_relative_positions(self):
         # get relative_coords_table
-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
+        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], device="cpu").to(torch.float32)
+        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], device="cpu").to(torch.float32)
         relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
         relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
         if self.pretrained_window_size[0] > 0:
@@ -141,8 +141,8 @@ def _make_pair_wise_relative_positions(self):
         self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
 
         # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
+        coords_h = torch.arange(self.window_size[0], device="cpu")
+        coords_w = torch.arange(self.window_size[1], device="cpu")
         coords = torch.stack(ndgrid(coords_h, coords_w))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
@@ -293,7 +293,7 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
             if x is None:
-                img_mask = torch.zeros((1, *self.input_resolution, 1))  # 1 H W 1
+                img_mask = torch.zeros((1, *self.input_resolution, 1), device="cpu")  # 1 H W 1
             else:
                 img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device)  # 1 H W 1
             cnt = 0

timm/models/swin_transformer_v2_cr.py

Lines changed: 3 additions & 2 deletions
@@ -141,7 +141,8 @@ def __init__(
 
     def _make_pair_wise_relative_positions(self) -> None:
         """Method initializes the pair-wise relative positions to compute the positional biases."""
-        device = self.logit_scale.device
+        # device = self.logit_scale.device
+        device = "cpu"
         coordinates = torch.stack(ndgrid(
             torch.arange(self.window_size[0], device=device),
             torch.arange(self.window_size[1], device=device)
@@ -314,7 +315,7 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
             if x is None:
-                img_mask = torch.zeros((1, *self.feat_size, 1))  # 1 H W 1
+                img_mask = torch.zeros((1, *self.feat_size, 1), device="cpu")  # 1 H W 1
             else:
                 img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device)  # 1 H W 1
             cnt = 0
