Skip to content

Commit 5c38a67

Browse files
authored
add weights_only argument for torch.load (#3997)
1 parent c5893b1 commit 5c38a67

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

models/public/regnetx-3.2gf/model.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pycls.core.checkpoint
15+
import torch
16+
import pycls.core.config
1617
import pycls.models.model_zoo
18+
from pycls.core.checkpoint import unwrap_model
1719

1820
def regnet(config_path, weights_path):
1921
pycls.core.config.cfg.merge_from_file(config_path)
2022
model = pycls.models.model_zoo.RegNet()
21-
pycls.core.checkpoint.load_checkpoint(weights_path, model)
23+
checkpoint = torch.load(weights_path, map_location="cpu", weights_only=False)
24+
test_err = checkpoint.get("test_err", 100)
25+
ema_err = checkpoint.get("ema_err", 100)
26+
ema_state = "ema_state" if "ema_state" in checkpoint else "model_state"
27+
best_state = "model_state" if test_err <= ema_err else ema_state
28+
unwrap_model(model).load_state_dict(checkpoint[best_state])
2229
return model

tools/model_tools/src/omz_tools/internal_scripts/pytorch_to_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def load_model(model_name, weights, model_paths, module_name, model_params):
144144

145145
try:
146146
if weights:
147-
model.load_state_dict(torch.load(weights, map_location='cpu'))
147+
model.load_state_dict(torch.load(weights, map_location='cpu', weights_only=False))
148148
except RuntimeError as err:
149149
print('ERROR: Weights from {} cannot be loaded for model {}! Check matching between model and weights'.format(
150150
weights, model_name))

0 commit comments

Comments
 (0)