Skip to content

Commit e0cb669

Browse files
committed
Make features_only=True work with mnv5 & enc, uses forward_intermediates()
1 parent 739b46c commit e0cb669

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

timm/models/mobilenetv5.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
554554

555555

556556
def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
557+
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
558+
feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
557559
kwargs_filter = (
558560
'num_classes',
559561
'num_features',
@@ -567,18 +569,22 @@ def _create_mnv5_encoder(variant: str, pretrained: bool = False, **kwargs) -> Mo
567569
variant,
568570
pretrained,
569571
pretrained_strict=False,
572+
feature_cfg=feature_cfg,
570573
kwargs_filter=kwargs_filter,
571574
**kwargs,
572575
)
573576
return model
574577

575578

576579
def _create_mnv5(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV5Encoder:
580+
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
581+
feature_cfg = dict(out_indices=out_indices, feature_cls='getter')
577582
model = build_model_with_cfg(
578583
MobileNetV5,
579584
variant,
580585
pretrained,
581586
pretrained_strict=False,
587+
feature_cfg=feature_cfg,
582588
**kwargs,
583589
)
584590
return model

0 commit comments

Comments
 (0)