Skip to content

Commit e6da810

Browse files
blaz-rsamet-akcay
andauthored
Fix SuperSimpleNet pretrained weights. (#2712)
* Add pretrained weights fix Signed-off-by: blaz-r <blaz.rolih@gmail.com> * Update SSN readme Co-authored-by: Samet Akcay <samet.akcay@intel.com> Signed-off-by: Blaž Rolih <61357777+blaz-r@users.noreply.github.com> * Fix readme linting error Signed-off-by: blaz-r <blaz.rolih@gmail.com> --------- Signed-off-by: blaz-r <blaz.rolih@gmail.com> Signed-off-by: Blaž Rolih <61357777+blaz-r@users.noreply.github.com> Co-authored-by: Samet Akcay <samet.akcay@intel.com>
1 parent 9887e1e commit e6da810

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/anomalib/models/image/supersimplenet/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ This implementation supports both unsupervised and supervised setting, but Anoma
2828

2929
`anomalib train --model SuperSimpleNet --data MVTecAD --data.category <category>`
3030

31+
> IMPORTANT!
32+
>
33+
> The model is verified to work with WideResNet50 using torchvision V1 weights.
34+
> It should work with most ResNets and WideResNets, but make sure you use V1 weights if you use default noise std value.
35+
> Correct weight name ends with ".tv\_[...]", not "tv2" (e.g. "wide_resnet50_2.tv_in1k").
36+
>
3137
> It is recommended to train the model for 300 epochs with batch size of 32 to achieve stable training with random anomaly generation. Training with lower parameter values will still work, but might not yield the optimal results.
3238
>
3339
> For supervised learning, refer to the [official code](https://github.com/blaz-r/SuperSimpleNet).

src/anomalib/models/image/supersimplenet/lightning_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class Supersimplenet(AnomalibModule):
6161
6262
Args:
6363
perlin_threshold (float): threshold value for Perlin noise thresholding during anomaly generation.
64-
backbone (str): backbone name
64+
backbone (str): backbone name. IMPORTANT! use only backbones with torchvision V1 weights ending on ".tv".
6565
layers (list[str]): backbone layers utilised
6666
supervised (bool): whether the model will be trained in supervised mode. False by default (unsupervised).
6767
pre_processor (PreProcessor | bool, optional): Pre-processor instance or
@@ -77,7 +77,7 @@ class Supersimplenet(AnomalibModule):
7777
def __init__(
7878
self,
7979
perlin_threshold: float = 0.2,
80-
backbone: str = "wide_resnet50_2",
80+
backbone: str = "wide_resnet50_2.tv_in1k", # IMPORTANT: use .tv weights, not tv2
8181
layers: list[str] = ["layer2", "layer3"], # noqa: B006
8282
supervised: bool = False,
8383
pre_processor: PreProcessor | bool = True,

src/anomalib/models/image/supersimplenet/torch_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ class SupersimplenetModel(nn.Module):
3333
3434
Args:
3535
perlin_threshold (float): threshold value for Perlin noise thresholding during anomaly generation.
36-
backbone (str): backbone name
36+
backbone (str): backbone name. IMPORTANT! use only backbones with torchvision V1 weights ending on ".tv".
3737
layers (list[str]): backbone layers utilised
3838
stop_grad (bool): whether to stop gradient from class. to seg. head.
3939
"""
4040

4141
def __init__(
4242
self,
4343
perlin_threshold: float = 0.2,
44-
backbone: str = "wide_resnet50_2",
44+
backbone: str = "wide_resnet50_2.tv_in1k", # IMPORTANT: use .tv weights, not tv2
4545
layers: list[str] = ["layer2", "layer3"], # noqa: B006
4646
stop_grad: bool = True,
4747
) -> None:

0 commit comments

Comments
 (0)