diff --git a/model.py b/model.py index bc9ba5c6..9c7a08d1 100644 --- a/model.py +++ b/model.py @@ -11,7 +11,7 @@ is_new_version = LooseVersion(torchvision.__version__) >= LooseVersion("0.13.0") if is_new_version: - from torchvision.models import ResNet50_Weights, DenseNet121_Weights + from torchvision.models import ResNet34_Weights, DenseNet121_Weights else: pass @@ -65,7 +65,7 @@ def __init__(self, backbone='resnet50', pretrained=True, weights=None): if is_old_version: self.encoder = getattr(models, backbone)(pretrained=pretrained) elif is_new_version: - self.encoder = getattr(models, backbone)(weights=ResNet50_Weights.IMAGENET1K_V1) + self.encoder = getattr(models, backbone)(weights=ResNet34_Weights.IMAGENET1K_V1) del self.encoder.fc, self.encoder.avgpool def forward(self, x):