File tree Expand file tree Collapse file tree 3 files changed +9
-9
lines changed Expand file tree Collapse file tree 3 files changed +9
-9
lines changed Original file line number Diff line number Diff line change @@ -556,19 +556,19 @@ def __init__(
556
556
557
557
dpr = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
558
558
stages = []
559
- for stage_idx in range (num_stages ):
560
- out_chs = embed_dims [stage_idx ]
559
+ for i in range (num_stages ):
560
+ out_chs = embed_dims [i ]
561
561
stage = DaVitStage (
562
562
in_chs ,
563
563
out_chs ,
564
- depth = depths [stage_idx ],
565
- downsample = stage_idx > 0 ,
564
+ depth = depths [i ],
565
+ downsample = i > 0 ,
566
566
attn_types = attn_types ,
567
- num_heads = num_heads [stage_idx ],
567
+ num_heads = num_heads [i ],
568
568
window_size = window_size ,
569
569
mlp_ratio = mlp_ratio ,
570
570
qkv_bias = qkv_bias ,
571
- drop_path_rates = dpr [stage_idx ],
571
+ drop_path_rates = dpr [i ],
572
572
norm_layer = norm_layer ,
573
573
norm_layer_cl = norm_layer_cl ,
574
574
ffn = ffn ,
@@ -579,7 +579,7 @@ def __init__(
579
579
)
580
580
in_chs = out_chs
581
581
stages .append (stage )
582
- self .feature_info += [dict (num_chs = out_chs , reduction = 2 , module = f'stages.{ stage_idx } ' )]
582
+ self .feature_info += [dict (num_chs = out_chs , reduction = 2 ** ( i + 2 ) , module = f'stages.{ i } ' )]
583
583
584
584
self .stages = nn .Sequential (* stages )
585
585
Original file line number Diff line number Diff line change @@ -407,7 +407,7 @@ def __init__(
407
407
)
408
408
prev_dim = embed_dims [i ]
409
409
stages .append (stage )
410
- self .feature_info += [dict (num_chs = embed_dims [i ], reduction = 2 ** (1 + i ), module = f'stages.{ i } ' )]
410
+ self .feature_info += [dict (num_chs = embed_dims [i ], reduction = 2 ** (i + 2 ), module = f'stages.{ i } ' )]
411
411
self .stages = nn .Sequential (* stages )
412
412
413
413
# Classifier head
Original file line number Diff line number Diff line change @@ -541,7 +541,7 @@ def __init__(
541
541
** kwargs ,
542
542
)]
543
543
prev_dim = dims [i ]
544
- self .feature_info += [dict (num_chs = dims [i ], reduction = 2 , module = f'stages.{ i } ' )]
544
+ self .feature_info += [dict (num_chs = dims [i ], reduction = 2 ** ( i + 2 ) , module = f'stages.{ i } ' )]
545
545
546
546
self .stages = nn .Sequential (* stages )
547
547
You can’t perform that action at this time.
0 commit comments