Skip to content

Commit a02b1a8

Browse files
authored
Merge pull request #2369 from brianhou0208/fix_reduction
Fix feature_info.reduction
2 parents ea23107 + ab0a70d commit a02b1a8

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

timm/models/davit.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -556,19 +556,19 @@ def __init__(
556556

557557
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
558558
stages = []
559-
for stage_idx in range(num_stages):
560-
out_chs = embed_dims[stage_idx]
559+
for i in range(num_stages):
560+
out_chs = embed_dims[i]
561561
stage = DaVitStage(
562562
in_chs,
563563
out_chs,
564-
depth=depths[stage_idx],
565-
downsample=stage_idx > 0,
564+
depth=depths[i],
565+
downsample=i > 0,
566566
attn_types=attn_types,
567-
num_heads=num_heads[stage_idx],
567+
num_heads=num_heads[i],
568568
window_size=window_size,
569569
mlp_ratio=mlp_ratio,
570570
qkv_bias=qkv_bias,
571-
drop_path_rates=dpr[stage_idx],
571+
drop_path_rates=dpr[i],
572572
norm_layer=norm_layer,
573573
norm_layer_cl=norm_layer_cl,
574574
ffn=ffn,
@@ -579,7 +579,7 @@ def __init__(
579579
)
580580
in_chs = out_chs
581581
stages.append(stage)
582-
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
582+
self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]
583583

584584
self.stages = nn.Sequential(*stages)
585585

timm/models/efficientformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def __init__(
407407
)
408408
prev_dim = embed_dims[i]
409409
stages.append(stage)
410-
self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
410+
self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(i+2), module=f'stages.{i}')]
411411
self.stages = nn.Sequential(*stages)
412412

413413
# Classifier head

timm/models/metaformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def __init__(
541541
**kwargs,
542542
)]
543543
prev_dim = dims[i]
544-
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
544+
self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]
545545

546546
self.stages = nn.Sequential(*stages)
547547

0 commit comments

Comments
 (0)