Skip to content

Commit dec9305

Browse files
committed
Improve TorchScript model identification logic in torch_model.py
- Updated copyright notice to reflect the years 2021-2025. - Enhanced the logic for identifying TorchScript models by clarifying the required files and directories in the zip archive format. Signed-off-by: Gigon Bae <gbae@nvidia.com>
1 parent 3fb64ee commit dec9305

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

monai/deploy/core/models/torch_model.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021 MONAI Consortium
1+
# Copyright 2021-2025 MONAI Consortium
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -24,11 +24,12 @@ class TorchScriptModel(Model):
2424
"""Represents TorchScript model.
2525
2626
TorchScript serialization format (TorchScript model file) is created by torch.jit.save() method and
27-
the serialized model (which usually has .pt or .pth extension) is a ZIP archive containing many files.
27+
the serialized model (which usually has .pt or .pth extension) is a ZIP archive.
2828
(https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md)
2929
30-
We consider that the model is a torchscript model if its unzipped archive contains files named 'data.pkl' and
31-
'constants.pkl', and folders named 'code' and 'data'.
30+
We identify a file as a TorchScript model if its unzipped archive contains a 'code/' directory
31+
and a 'data.pkl' file. For tensor constants, it may contain either a 'constants.pkl' file (older format)
32+
or a 'constants/' directory (newer format).
3233
3334
When predictor property is accessed or the object is called (__call__), the model is loaded in `evaluation mode`
3435
from the serialized model file (if it is not loaded yet) and the model is ready to be used.
@@ -85,31 +86,38 @@ def train(self, mode: bool = True) -> "TorchScriptModel":
8586

8687
@classmethod
8788
def accept(cls, path: str):
88-
prefix_code = False
89-
prefix_data = False
90-
prefix_constants_pkl = False
91-
prefix_data_pkl = False
89+
# These are the files and directories we expect to find in a TorchScript zip archive.
90+
# See: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/docs/serialization.md
91+
has_code_dir = False
92+
has_constants_dir = False
93+
has_constants_pkl = False
94+
has_data_pkl = False
9295

9396
if not os.path.isfile(path):
9497
return False, None
9598

9699
try:
97-
zip_file = ZipFile(path)
98-
for data in zip_file.filelist:
99-
file_name = data.filename
100-
pivot = file_name.find("/")
101-
if pivot != -1 and not prefix_code and file_name[pivot:].startswith("/code/"):
102-
prefix_code = True
103-
if pivot != -1 and not prefix_data and file_name[pivot:].startswith("/data/"):
104-
prefix_data = True
105-
if pivot != -1 and not prefix_constants_pkl and file_name[pivot:] == "/constants.pkl":
106-
prefix_constants_pkl = True
107-
if pivot != -1 and not prefix_data_pkl and file_name[pivot:] == "/data.pkl":
108-
prefix_data_pkl = True
109-
except BadZipFile:
100+
with ZipFile(path) as zip_file:
101+
# Top-level directory name in the zip file (e.g., 'model_name/')
102+
top_level_dir = ""
103+
if "/" in zip_file.filelist[0].filename:
104+
top_level_dir = zip_file.filelist[0].filename.split("/", 1)[0] + "/"
105+
106+
filenames = {f.filename for f in zip_file.filelist}
107+
108+
# Check for required files and directories
109+
has_data_pkl = (top_level_dir + "data.pkl") in filenames
110+
has_code_dir = any(f.startswith(top_level_dir + "code/") for f in filenames)
111+
112+
# Check for either constants.pkl (older format) or constants/ (newer format)
113+
has_constants_pkl = (top_level_dir + "constants.pkl") in filenames
114+
has_constants_dir = any(f.startswith(top_level_dir + "constants/") for f in filenames)
115+
116+
except (BadZipFile, IndexError):
110117
return False, None
111118

112-
if prefix_code and prefix_data and prefix_constants_pkl and prefix_data_pkl:
119+
# A valid TorchScript model must have code/, data.pkl, and either constants.pkl or constants/
120+
if has_code_dir and has_data_pkl and (has_constants_pkl or has_constants_dir):
113121
return True, cls.model_type
114122

115123
return False, None

0 commit comments

Comments
 (0)