Skip to content

Commit 809a2b5

Browse files
committed
more stochastic depth stuff
1 parent 3b0d330 commit 809a2b5

File tree

8 files changed

+8
-8
lines changed

8 files changed

+8
-8
lines changed

timm/models/maxxvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ def __init__(
11741174

11751175
num_stages = len(cfg.embed_dim)
11761176
assert len(cfg.depths) == num_stages
1177-
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
1177+
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths), device="cpu").split(cfg.depths)]
11781178
in_chs = self.stem.out_chs
11791179
stages = []
11801180
for i in range(num_stages):

timm/models/mvitv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def __init__(
749749
num_stages = len(cfg.embed_dim)
750750
feat_size = patch_dims
751751
curr_stride = max(cfg.patch_stride)
752-
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
752+
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths), device="cpu").split(cfg.depths)]
753753
self.stages = nn.ModuleList()
754754
self.feature_info = []
755755
for i in range(num_stages):

timm/models/nfnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def __init__(
337337
)
338338

339339
self.feature_info = [stem_feat]
340-
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
340+
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths), device="cpu").split(cfg.depths)]
341341
prev_chs = stem_chs
342342
net_stride = stem_stride
343343
dilation = 1

timm/models/pit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def __init__(
183183

184184
transformers = []
185185
# stochastic depth decay rule
186-
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)]
186+
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth), device="cpu").split(depth)]
187187
prev_dim = embed_dim
188188
for i in range(len(depth)):
189189
pool = None

timm/models/rdnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __init__(
213213
self.num_stages = len(growth_rates)
214214
curr_stride = stem_stride
215215
num_features = num_init_features
216-
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(num_blocks_list)).split(num_blocks_list)]
216+
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(num_blocks_list), device="cpu").split(num_blocks_list)]
217217

218218
dense_stages = []
219219
for i in range(self.num_stages):

timm/models/resnetv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def __init__(
467467
prev_chs = stem_chs
468468
curr_stride = 4
469469
dilation = 1
470-
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
470+
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers), device="cpu").split(layers)]
471471
if preact:
472472
block_fn = PreActBasic if basic else PreActBottleneck
473473
else:

timm/models/tresnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__(
135135
self.inplanes = self.inplanes // 8 * 8
136136
self.planes = self.planes // 8 * 8
137137

138-
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
138+
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers), device="cpu").split(layers)]
139139
conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer)
140140
layer1 = self._make_layer(
141141
Bottleneck if v2 else BasicBlock,

timm/models/vovnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __init__(
213213
current_stride = stem_stride
214214

215215
# OSA stages
216-
stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage)
216+
stage_dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(block_per_stage), device="cpu").split(block_per_stage)]
217217
in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
218218
stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
219219
stages = []

0 commit comments

Comments
 (0)