@@ -63,7 +63,7 @@ def __init__(
         self,
         imagenet_dir: Path | str = "./datasets/imagenette",
         teacher_out_channels: int = 384,
-        model_size: EfficientAdModelSize = EfficientAdModelSize.S,
+        model_size: EfficientAdModelSize | str = EfficientAdModelSize.S,
         lr: float = 0.0001,
         weight_decay: float = 0.00001,
         padding: bool = False,
@@ -72,24 +72,27 @@ def __init__(
         super().__init__()

         self.imagenet_dir = Path(imagenet_dir)
-        self.model_size = model_size
+        if not isinstance(model_size, EfficientAdModelSize):
+            model_size = EfficientAdModelSize(model_size)
+        self.model_size: EfficientAdModelSize = model_size
         self.model: EfficientAdModel = EfficientAdModel(
             teacher_out_channels=teacher_out_channels,
             model_size=model_size,
             padding=padding,
             pad_maps=pad_maps,
         )
-        self.batch_size = 1  # imagenet dataloader batch_size is 1 according to the paper
-        self.lr = lr
-        self.weight_decay = weight_decay
+        self.batch_size: int = 1  # imagenet dataloader batch_size is 1 according to the paper
+        self.lr: float = lr
+        self.weight_decay: float = weight_decay

     def prepare_pretrained_model(self) -> None:
         """Prepare the pretrained teacher model."""
         pretrained_models_dir = Path("./pre_trained/")
         if not (pretrained_models_dir / "efficientad_pretrained_weights").is_dir():
             download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO)
+        model_size_str = self.model_size.value if isinstance(self.model_size, EfficientAdModelSize) else self.model_size
         teacher_path = (
-            pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{self.model_size.value}.pth"
+            pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
         )
         logger.info(f"Load pretrained teacher model from {teacher_path}")
         self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device)))
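Usage sketch of what the widened model_size signature enables, for reviewers. This is not part of the diff: the EfficientAd class name, the import paths, and the "small" enum value are assumptions about anomalib's public API; with the coercion added in __init__, the two calls are expected to be equivalent.

# Hedged sketch (assumptions: EfficientAd is exported from anomalib.models, and
# EfficientAdModelSize is a string-valued enum whose S member equals "small").
from anomalib.models import EfficientAd
from anomalib.models.image.efficient_ad.torch_model import EfficientAdModelSize

# The plain string is coerced to the enum inside __init__ via EfficientAdModelSize(model_size),
# so both instances should end up with the same model_size member.
model_from_enum = EfficientAd(model_size=EfficientAdModelSize.S)
model_from_str = EfficientAd(model_size="small")
assert model_from_enum.model_size is model_from_str.model_size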