diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py index ce850cd..5f4ee78 100755 --- a/efficientnet_pytorch/model.py +++ b/efficientnet_pytorch/model.py @@ -376,7 +376,7 @@ def from_pretrained(cls, model_name, weights_path=None, advprop=False, """ model = cls.from_name(model_name, num_classes=num_classes, **override_params) load_pretrained_weights(model, model_name, weights_path=weights_path, - load_fc=(num_classes == 1000), advprop=advprop) + load_fc=(num_classes == 1000) and model._global_params.include_top, advprop=advprop) model._change_in_channels(in_channels) return model diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py index 826a627..c95317e 100755 --- a/efficientnet_pytorch/utils.py +++ b/efficientnet_pytorch/utils.py @@ -608,7 +608,7 @@ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, state_dict.pop('_fc.weight') state_dict.pop('_fc.bias') ret = model.load_state_dict(state_dict, strict=False) - assert set(ret.missing_keys) == set( + assert not ret.missing_keys or set(ret.missing_keys) == set( ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)