Skip to content

Commit 9582b1d

Browse files
Gornoka and samet-akcay authored
fix: efficient ad model_size str fixes (#2159)
* fix: model str conversion

  Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>

* fix: str support for model_size param in efficient ad

  Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>

* chore: update type hints for efficient_ad init, also pre-commit

  Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>

---------

Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
1 parent b646d1a commit 9582b1d

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/anomalib/models/image/efficient_ad/lightning_model.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
self,
6464
imagenet_dir: Path | str = "./datasets/imagenette",
6565
teacher_out_channels: int = 384,
66-
model_size: EfficientAdModelSize = EfficientAdModelSize.S,
66+
model_size: EfficientAdModelSize | str = EfficientAdModelSize.S,
6767
lr: float = 0.0001,
6868
weight_decay: float = 0.00001,
6969
padding: bool = False,
@@ -72,24 +72,27 @@ def __init__(
7272
super().__init__()
7373

7474
self.imagenet_dir = Path(imagenet_dir)
75-
self.model_size = model_size
75+
if not isinstance(model_size, EfficientAdModelSize):
76+
model_size = EfficientAdModelSize(model_size)
77+
self.model_size: EfficientAdModelSize = model_size
7678
self.model: EfficientAdModel = EfficientAdModel(
7779
teacher_out_channels=teacher_out_channels,
7880
model_size=model_size,
7981
padding=padding,
8082
pad_maps=pad_maps,
8183
)
82-
self.batch_size = 1 # imagenet dataloader batch_size is 1 according to the paper
83-
self.lr = lr
84-
self.weight_decay = weight_decay
84+
self.batch_size: int = 1 # imagenet dataloader batch_size is 1 according to the paper
85+
self.lr: float = lr
86+
self.weight_decay: float = weight_decay
8587

8688
def prepare_pretrained_model(self) -> None:
8789
"""Prepare the pretrained teacher model."""
8890
pretrained_models_dir = Path("./pre_trained/")
8991
if not (pretrained_models_dir / "efficientad_pretrained_weights").is_dir():
9092
download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO)
93+
model_size_str = self.model_size.value if isinstance(self.model_size, EfficientAdModelSize) else self.model_size
9194
teacher_path = (
92-
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{self.model_size.value}.pth"
95+
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
9396
)
9497
logger.info(f"Load pretrained teacher model from {teacher_path}")
9598
self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device)))

0 commit comments

Comments (0)