Skip to content

Commit 3b0d330

Browse files
committed
handle stochastic depth
1 parent fcdb200 commit 3b0d330

33 files changed: +33 -34 lines changed

timm/models/beit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -326,7 +326,7 @@ def __init__(
326 326   else:
327 327   self.rel_pos_bias = None
328 328
329     - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
    329 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device="cpu")] # stochastic depth decay rule
330 330   self.blocks = nn.ModuleList([
331 331   Block(
332 332   dim=embed_dim,

timm/models/byobnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1113,7 +1113,7 @@ def create_byob_stages(
1113 1113   feature_info = []
1114 1114   block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
1115 1115   depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
1116      - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
     1116 + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths), device="cpu").split(depths)]
1117 1117   dilation = 1
1118 1118   net_stride = stem_feat['reduction']
1119 1119   prev_chs = stem_feat['num_chs']

timm/models/convit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -292,7 +292,7 @@ def __init__(
292 292   self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
293 293   trunc_normal_(self.pos_embed, std=.02)
294 294
295     - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
    295 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device="cpu")] # stochastic depth decay rule
296 296   self.blocks = nn.ModuleList([
297 297   Block(
298 298   dim=embed_dim,

timm/models/convnext.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -328,7 +328,7 @@ def __init__(
328 328   stem_stride = 4
329 329
330 330   self.stages = nn.Sequential()
331     - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    331 + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths), device="cpu").split(depths)]
332 332   stages = []
333 333   prev_chs = dims[0]
334 334   curr_stride = stem_stride

timm/models/crossvit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -351,7 +351,7 @@ def __init__(
351 351   self.pos_drop = nn.Dropout(p=pos_drop_rate)
352 352
353 353   total_depth = sum([sum(x[-2:]) for x in depth])
354     - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
    354 + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth, device="cpu")] # stochastic depth decay rule
355 355   dpr_ptr = 0
356 356   self.blocks = nn.ModuleList()
357 357   for idx, block_cfg in enumerate(depth):

timm/models/cspnet.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -569,7 +569,7 @@ def create_csp_stages(
569 569   cfg_dict = asdict(cfg.stages)
570 570   num_stages = len(cfg.stages.depth)
571 571   cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
572     - [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
    572 + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth), device="cpu").split(cfg.stages.depth)]
573 573   stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
574 574   block_kwargs = dict(
575 575   act_layer=cfg.act_layer,

timm/models/davit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -554,7 +554,7 @@ def __init__(
554 554   self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
555 555   in_chs = embed_dims[0]
556 556
557     - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    557 + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths), device="cpu").split(depths)]
558 558   stages = []
559 559   for stage_idx in range(num_stages):
560 560   out_chs = embed_dims[stage_idx]

timm/models/edgenext.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -342,7 +342,7 @@ def __init__(
342 342
343 343   curr_stride = 4
344 344   stages = []
345     - dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    345 + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths), device="cpu").split(depths)]
346 346   in_chs = dims[0]
347 347   for i in range(4):
348 348   stride = 2 if curr_stride == 2 or i > 0 else 1

timm/models/efficientformer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -385,7 +385,7 @@ def __init__(
385 385   # stochastic depth decay rule
386 386   self.num_stages = len(depths)
387 387   last_stage = self.num_stages - 1
388     - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    388 + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths), device="cpu").split(depths)]
389 389   downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1)
390 390   stages = []
391 391   self.feature_info = []

timm/models/efficientformer_v2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -542,7 +542,7 @@ def __init__(
542 542   stride = 4
543 543
544 544   num_stages = len(depths)
545     - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    545 + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths), device="cpu").split(depths)]
546 546   downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
547 547   mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
548 548   stages = []

0 commit comments

Comments
 (0)