1
- # Copyright 2021 MONAI Consortium
1
+ # Copyright 2021-2025 MONAI Consortium
2
2
# Licensed under the Apache License, Version 2.0 (the "License");
3
3
# you may not use this file except in compliance with the License.
4
4
# You may obtain a copy of the License at
@@ -24,11 +24,12 @@ class TorchScriptModel(Model):
24
24
"""Represents TorchScript model.
25
25
26
26
TorchScript serialization format (TorchScript model file) is created by torch.jit.save() method and
27
- the serialized model (which usually has .pt or .pth extension) is a ZIP archive containing many files .
27
+ the serialized model (which usually has .pt or .pth extension) is a ZIP archive.
28
28
(https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md)
29
29
30
- We consider that the model is a torchscript model if its unzipped archive contains files named 'data.pkl' and
31
- 'constants.pkl', and folders named 'code' and 'data'.
30
+ We identify a file as a TorchScript model if its unzipped archive contains a 'code/' directory
31
+ and a 'data.pkl' file. For tensor constants, it may contain either a 'constants.pkl' file (older format)
32
+ or a 'constants/' directory (newer format).
32
33
33
34
When predictor property is accessed or the object is called (__call__), the model is loaded in `evaluation mode`
34
35
from the serialized model file (if it is not loaded yet) and the model is ready to be used.
@@ -85,31 +86,38 @@ def train(self, mode: bool = True) -> "TorchScriptModel":
85
86
86
87
@classmethod
87
88
def accept (cls , path : str ):
88
- prefix_code = False
89
- prefix_data = False
90
- prefix_constants_pkl = False
91
- prefix_data_pkl = False
89
+ # These are the files and directories we expect to find in a TorchScript zip archive.
90
+ # See: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/docs/serialization.md
91
+ has_code_dir = False
92
+ has_constants_dir = False
93
+ has_constants_pkl = False
94
+ has_data_pkl = False
92
95
93
96
if not os .path .isfile (path ):
94
97
return False , None
95
98
96
99
try :
97
- zip_file = ZipFile (path )
98
- for data in zip_file .filelist :
99
- file_name = data .filename
100
- pivot = file_name .find ("/" )
101
- if pivot != - 1 and not prefix_code and file_name [pivot :].startswith ("/code/" ):
102
- prefix_code = True
103
- if pivot != - 1 and not prefix_data and file_name [pivot :].startswith ("/data/" ):
104
- prefix_data = True
105
- if pivot != - 1 and not prefix_constants_pkl and file_name [pivot :] == "/constants.pkl" :
106
- prefix_constants_pkl = True
107
- if pivot != - 1 and not prefix_data_pkl and file_name [pivot :] == "/data.pkl" :
108
- prefix_data_pkl = True
109
- except BadZipFile :
100
+ with ZipFile (path ) as zip_file :
101
+ # Top-level directory name in the zip file (e.g., 'model_name/')
102
+ top_level_dir = ""
103
+ if "/" in zip_file .filelist [0 ].filename :
104
+ top_level_dir = zip_file .filelist [0 ].filename .split ("/" , 1 )[0 ] + "/"
105
+
106
+ filenames = {f .filename for f in zip_file .filelist }
107
+
108
+ # Check for required files and directories
109
+ has_data_pkl = (top_level_dir + "data.pkl" ) in filenames
110
+ has_code_dir = any (f .startswith (top_level_dir + "code/" ) for f in filenames )
111
+
112
+ # Check for either constants.pkl (older format) or constants/ (newer format)
113
+ has_constants_pkl = (top_level_dir + "constants.pkl" ) in filenames
114
+ has_constants_dir = any (f .startswith (top_level_dir + "constants/" ) for f in filenames )
115
+
116
+ except (BadZipFile , IndexError ):
110
117
return False , None
111
118
112
- if prefix_code and prefix_data and prefix_constants_pkl and prefix_data_pkl :
119
+ # A valid TorchScript model must have code/, data.pkl, and either constants.pkl or constants/
120
+ if has_code_dir and has_data_pkl and (has_constants_pkl or has_constants_dir ):
113
121
return True , cls .model_type
114
122
115
123
return False , None
0 commit comments