Skip to content

Improve TorchScript model identification logic in torch_model.py #546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions monai/deploy/core/models/torch_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 MONAI Consortium
# Copyright 2021-2025 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -24,11 +24,12 @@ class TorchScriptModel(Model):
"""Represents TorchScript model.

TorchScript serialization format (TorchScript model file) is created by torch.jit.save() method and
the serialized model (which usually has .pt or .pth extension) is a ZIP archive containing many files.
the serialized model (which usually has .pt or .pth extension) is a ZIP archive.
(https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md)

We consider that the model is a torchscript model if its unzipped archive contains files named 'data.pkl' and
'constants.pkl', and folders named 'code' and 'data'.
We identify a file as a TorchScript model if its unzipped archive contains a 'code/' directory
and a 'data.pkl' file. For tensor constants, it may contain either a 'constants.pkl' file (older format)
or a 'constants/' directory (newer format).

When predictor property is accessed or the object is called (__call__), the model is loaded in `evaluation mode`
from the serialized model file (if it is not loaded yet) and the model is ready to be used.
Expand Down Expand Up @@ -85,31 +86,38 @@ def train(self, mode: bool = True) -> "TorchScriptModel":

@classmethod
def accept(cls, path: str):
prefix_code = False
prefix_data = False
prefix_constants_pkl = False
prefix_data_pkl = False
# These are the files and directories we expect to find in a TorchScript zip archive.
# See: https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/docs/serialization.md
has_code_dir = False
has_constants_dir = False
has_constants_pkl = False
has_data_pkl = False

if not os.path.isfile(path):
return False, None

try:
zip_file = ZipFile(path)
for data in zip_file.filelist:
file_name = data.filename
pivot = file_name.find("/")
if pivot != -1 and not prefix_code and file_name[pivot:].startswith("/code/"):
prefix_code = True
if pivot != -1 and not prefix_data and file_name[pivot:].startswith("/data/"):
prefix_data = True
if pivot != -1 and not prefix_constants_pkl and file_name[pivot:] == "/constants.pkl":
prefix_constants_pkl = True
if pivot != -1 and not prefix_data_pkl and file_name[pivot:] == "/data.pkl":
prefix_data_pkl = True
except BadZipFile:
with ZipFile(path) as zip_file:
# Top-level directory name in the zip file (e.g., 'model_name/')
top_level_dir = ""
if "/" in zip_file.filelist[0].filename:
top_level_dir = zip_file.filelist[0].filename.split("/", 1)[0] + "/"

filenames = {f.filename for f in zip_file.filelist}

# Check for required files and directories
has_data_pkl = (top_level_dir + "data.pkl") in filenames
has_code_dir = any(f.startswith(top_level_dir + "code/") for f in filenames)

# Check for either constants.pkl (older format) or constants/ (newer format)
has_constants_pkl = (top_level_dir + "constants.pkl") in filenames
has_constants_dir = any(f.startswith(top_level_dir + "constants/") for f in filenames)

except (BadZipFile, IndexError):
return False, None

if prefix_code and prefix_data and prefix_constants_pkl and prefix_data_pkl:
# A valid TorchScript model must have code/, data.pkl, and either constants.pkl or constants/
if has_code_dir and has_data_pkl and (has_constants_pkl or has_constants_dir):
return True, cls.model_type

return False, None