Skip to content

Commit 754119a

Browse files
authored
🔒 Security Fix: Mitigate Remote Code Execution Risk in torch.load (#2729)
* Security Fix: Mitigate Remote Code Execution Risk in torch.load Signed-off-by: samet-akcay <samet.akcay@intel.com> * Add semgrep ignore line Signed-off-by: samet-akcay <samet.akcay@intel.com> * Add env var to unit tests Signed-off-by: samet-akcay <samet.akcay@intel.com> * Add env var to integration tests Signed-off-by: samet-akcay <samet.akcay@intel.com> * add weights only to true explicitly in all usages Signed-off-by: samet-akcay <samet.akcay@intel.com> * add weights only to true explicitly in all usages Signed-off-by: samet-akcay <samet.akcay@intel.com> * Address torch.Tensor semgrep issue Signed-off-by: samet-akcay <samet.akcay@intel.com> --------- Signed-off-by: samet-akcay <samet.akcay@intel.com>
1 parent 53e2cf4 commit 754119a

File tree

10 files changed

+155
-52
lines changed

10 files changed

+155
-52
lines changed

src/anomalib/callbacks/model_loader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,6 @@ def setup(self, trainer: Trainer, pl_module: AnomalibModule, stage: str | None =
8181
del trainer, stage # These variables are not used.
8282

8383
logger.info("Loading the model from %s", self.weights_path)
84-
pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"])
84+
pl_module.load_state_dict(
85+
torch.load(self.weights_path, map_location=pl_module.device, weights_only=True)["state_dict"],
86+
)

src/anomalib/deploy/inferencers/torch_inferencer.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,24 @@
33
This module provides the PyTorch inferencer implementation for running inference
44
with trained PyTorch models.
55
6+
.. warning::
7+
This is a legacy inferencer. It is recommended to use :class:`anomalib.engine.Engine.predict()`
8+
instead, which provides a more modern and feature-rich interface for model inference.
9+
10+
.. danger::
11+
**Security Notice**: PyTorch model loading uses Python's pickle module, which can execute code from the checkpoint
12+
file. This is a standard PyTorch behavior, not specific to this library. For security, load models only from trusted
13+
sources and consider using safer formats like ONNX or TorchScript for production use. Recommendations:
14+
15+
- Load models only from trusted sources
16+
- Consider using safer formats like ONNX or TorchScript for production use
17+
18+
To proceed with loading, set:
19+
20+
.. code-block:: bash
21+
22+
export TRUST_REMOTE_CODE=1
23+
624
Example:
725
Assume we have a PyTorch model saved as a ``.pt`` file:
826
@@ -36,6 +54,8 @@
3654
# Copyright (C) 2022-2025 Intel Corporation
3755
# SPDX-License-Identifier: Apache-2.0
3856

57+
import logging
58+
import os
3959
from pathlib import Path
4060

4161
import numpy as np
@@ -47,10 +67,30 @@
4767
from anomalib.data import ImageBatch
4868
from anomalib.data.utils import read_image
4969

70+
logger = logging.getLogger(__name__)
71+
5072

5173
class TorchInferencer:
5274
"""PyTorch inferencer for anomaly detection models.
5375
76+
.. warning::
77+
This is a legacy inferencer. It is recommended to use :class:`anomalib.engine.Engine.predict()`
78+
instead, which provides a more modern and feature-rich interface for model inference.
79+
80+
.. danger::
81+
**Security Notice**: PyTorch model loading uses Python's pickle module,
82+
which can execute code from the checkpoint file. This is a standard PyTorch behavior,
83+
not specific to this library. For security:
84+
85+
- Load models only from trusted sources
86+
- Consider using safer formats like ONNX or TorchScript for production use
87+
88+
To proceed with loading, set:
89+
90+
.. code-block:: bash
91+
92+
export TRUST_REMOTE_CODE=1
93+
5494
Args:
5595
path (str | Path): Path to the PyTorch model weights file.
5696
device (str, optional): Device to use for inference.
@@ -65,6 +105,7 @@ class TorchInferencer:
65105
Raises:
66106
ValueError: If an invalid device is specified.
67107
ValueError: If the model file has an unknown extension.
108+
ValueError: If TRUST_REMOTE_CODE environment variable is not set.
68109
KeyError: If the checkpoint file does not contain a model.
69110
"""
70111

@@ -73,6 +114,10 @@ def __init__(
73114
path: str | Path,
74115
device: str = "auto",
75116
) -> None:
117+
logger.warning(
118+
"TorchInferencer is a legacy inferencer. Consider using Engine.predict() instead, "
119+
"which provides a more modern and feature-rich interface for model inference.",
120+
)
76121
self.device = self._get_device(device)
77122

78123
# Load the model weights and metadata
@@ -118,6 +163,7 @@ def _load_checkpoint(self, path: str | Path) -> dict:
118163
119164
Raises:
120165
ValueError: If the model file has an unknown extension.
166+
ValueError: If TRUST_REMOTE_CODE environment variable is not set.
121167
122168
Example:
123169
>>> model = TorchInferencer(path="path/to/model.pt")
@@ -132,6 +178,23 @@ def _load_checkpoint(self, path: str | Path) -> dict:
132178
msg = f"Unknown PyTorch checkpoint format {path.suffix}. Make sure you save the PyTorch model."
133179
raise ValueError(msg)
134180

181+
trust_remote_code_enabled = os.environ.get("TRUST_REMOTE_CODE", "0").lower() in {"1", "true"}
182+
183+
if not trust_remote_code_enabled:
184+
msg = (
185+
"Loading this model checkpoint requires executing arbitrary code via Python's pickle module, "
186+
"which is disabled by default for security reasons. This can be exploited by malicious model files. "
187+
"If you trust the source of this model and understand the risks, "
188+
"set the environment variable `TRUST_REMOTE_CODE=1` to allow loading."
189+
)
190+
raise ValueError(msg)
191+
192+
logger.warning(
193+
"TRUST_REMOTE_CODE is set to True. Loading model using pickle module, "
194+
"which is inherently insecure and can lead to arbitrary code execution. "
195+
"Only set this to True if you TRUST the source of the checkpoint.",
196+
)
197+
# nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
135198
return torch.load(path, map_location=self.device, weights_only=False)
136199

137200
def load_model(self, path: str | Path) -> nn.Module:
@@ -184,10 +247,7 @@ def predict(self, image: str | Path | np.ndarray | PILImage | torch.Tensor) -> I
184247
image = self.pre_process(image)
185248
predictions = self.model(image)
186249

187-
return ImageBatch(
188-
image=image,
189-
**predictions._asdict(),
190-
)
250+
return ImageBatch(image=image, **predictions._asdict())
191251

192252
def pre_process(self, image: torch.Tensor) -> torch.Tensor:
193253
"""Pre-process the input image.

src/anomalib/models/__init__.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from omegaconf import DictConfig, OmegaConf
5353

5454
from anomalib.models.components import AnomalibModule
55-
from anomalib.utils.path import convert_to_snake_case
55+
from anomalib.utils.path import convert_snake_to_pascal_case, convert_to_snake_case, convert_to_title_case
5656

5757
from .image import (
5858
Cfa,
@@ -117,58 +117,63 @@ class UnknownModelError(ModuleNotFoundError):
117117
logger = logging.getLogger(__name__)
118118

119119

120-
def convert_snake_to_pascal_case(snake_case: str) -> str:
121-
"""Convert snake_case string to PascalCase.
120+
def list_models(case: str = "snake") -> set[str]:
121+
"""List available anomaly detection models.
122122
123-
This function takes a string in snake_case format (words separated by underscores)
124-
and converts it to PascalCase format (each word capitalized and concatenated).
123+
Returns a set of model names in the specified format that are available in the
124+
anomalib library. This includes both image and video anomaly detection models.
125125
126126
Args:
127-
snake_case (str): Input string in snake_case format (e.g. ``"efficient_ad"``)
127+
case (str): The format to return model names in. Options are:
128+
- "snake": Returns names in snake_case format (e.g. "efficient_ad")
129+
- "pascal": Returns the original PascalCase class names (e.g. "EfficientAd")
130+
Defaults to "snake". A third option, "title", returns Title Case names (e.g. "Efficient Ad").
128131
129132
Returns:
130-
str: Output string in PascalCase format (e.g. ``"EfficientAd"``)
131-
132-
Examples:
133-
>>> convert_snake_to_pascal_case("efficient_ad")
134-
'EfficientAd'
135-
>>> convert_snake_to_pascal_case("patchcore")
136-
'Patchcore'
137-
>>> convert_snake_to_pascal_case("reverse_distillation")
138-
'ReverseDistillation'
139-
"""
140-
return "".join(word.capitalize() for word in snake_case.split("_"))
141-
142-
143-
def get_available_models() -> set[str]:
144-
"""Get set of available anomaly detection models.
145-
146-
Returns a set of model names in snake_case format that are available in the
147-
anomalib library. This includes both image and video anomaly detection models.
148-
149-
Returns:
150-
set[str]: Set of available model names in snake_case format (e.g.
151-
``'efficient_ad'``, ``'padim'``, etc.)
133+
set[str]: Set of available model names in the specified format.
152134
153135
Example:
154-
Get all available models:
136+
Get all available models in different formats:
155137
156-
>>> from anomalib.models import get_available_models
157-
>>> models = get_available_models()
138+
>>> from anomalib.models import list_models
139+
>>> # Get models in snake_case format
140+
>>> models = list_models(case="snake")
158141
>>> print(sorted(list(models))) # doctest: +NORMALIZE_WHITESPACE
159142
['ai_vad', 'cfa', 'cflow', 'csflow', 'dfkde', 'dfm', 'draem',
160143
'efficient_ad', 'fastflow', 'fre', 'ganomaly', 'padim', 'patchcore',
161144
'reverse_distillation', 'stfpm', 'uflow', 'vlm_ad', 'winclip']
162145
146+
>>> # Get models in original PascalCase format
147+
>>> models = list_models(case="pascal")
148+
>>> print(sorted(list(models))) # doctest: +NORMALIZE_WHITESPACE
149+
['AiVad', 'Cfa', 'Cflow', 'Csflow', 'Dfkde', 'Dfm', 'Draem',
150+
'EfficientAd', 'Fastflow', 'Fre', 'Ganomaly', 'Padim', 'Patchcore',
151+
'ReverseDistillation', 'Stfpm', 'Uflow', 'VlmAd', 'WinClip']
152+
153+
>>> # Get models in title case format
154+
>>> models = list_models(case="title")
155+
>>> print(sorted(list(models))) # doctest: +NORMALIZE_WHITESPACE
156+
['Ai Vad', 'Cfa', 'Cflow', 'Csflow', 'Dfkde', 'Dfm', 'Draem',
157+
'Efficient Ad', 'Fastflow', 'Fre', 'Ganomaly', 'Padim', 'Patchcore',
158+
'Reverse Distillation', 'Stfpm', 'Uflow', 'Vlm Ad', 'Win Clip']
159+
163160
Note:
164161
The returned model names can be used with :func:`get_model` to instantiate
165162
the corresponding model class.
166163
"""
167-
return {
168-
convert_to_snake_case(cls.__name__)
169-
for cls in AnomalibModule.__subclasses__()
170-
if cls.__name__ != "AnomalyModule"
171-
}
164+
if case not in {"snake", "pascal", "title"}:
165+
msg = f"Unsupported format: {case}. Must be one of: snake, pascal, title"
166+
raise ValueError(msg)
167+
168+
models = {cls.__name__ for cls in AnomalibModule.__subclasses__() if cls.__name__ != "AnomalyModule"}
169+
170+
if case == "snake":
171+
return {convert_to_snake_case(name) for name in models}
172+
173+
if case == "title":
174+
return {convert_to_title_case(name) for name in models}
175+
176+
return models
172177

173178

174179
def _get_model_class_by_name(name: str) -> type[AnomalibModule]:
@@ -207,7 +212,7 @@ def _get_model_class_by_name(name: str) -> type[AnomalibModule]:
207212
if name == model.__name__.lower():
208213
model_class = model
209214
if model_class is None:
210-
logger.exception(f"Could not find the model {name}. Available models are {get_available_models()}")
215+
logger.exception(f"Could not find the model {name}. Available models are {list_models()}")
211216
raise UnknownModelError
212217

213218
return model_class
@@ -288,7 +293,7 @@ def get_model(model: DictConfig | str | dict | Namespace, *args, **kwdargs) -> A
288293
module = import_module("anomalib.models")
289294
except ModuleNotFoundError as exception:
290295
logger.exception(
291-
f"Could not find the module {model.class_path}. Available models are {get_available_models()}",
296+
f"Could not find the module {model.class_path}. Available models are {list_models()}",
292297
)
293298
raise UnknownModelError from exception
294299
try:
@@ -299,7 +304,7 @@ def get_model(model: DictConfig | str | dict | Namespace, *args, **kwdargs) -> A
299304
_model = model_class(*args, **init_args)
300305
except AttributeError as exception:
301306
logger.exception(
302-
f"Could not find the model {model.class_path}. Available models are {get_available_models()}",
307+
f"Could not find the model {model.class_path}. Available models are {list_models()}",
303308
)
304309
raise UnknownModelError from exception
305310
else:

src/anomalib/models/image/dsr/torch_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def load_pretrained_discrete_model_weights(self, ckpt: Path, device: torch.devic
128128
device (torch.device | str | None, optional): Device to load weights
129129
to. Defaults to ``None``.
130130
"""
131-
self.discrete_latent_model.load_state_dict(torch.load(ckpt, map_location=device))
131+
self.discrete_latent_model.load_state_dict(torch.load(ckpt, map_location=device, weights_only=True))
132132

133133
def forward(
134134
self,
@@ -906,7 +906,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
906906

907907
# necessary to correctly load the checkpoint file
908908
self.register_buffer("_ema_cluster_size", torch.zeros(num_embeddings))
909-
self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
909+
self._ema_w = nn.Parameter(torch.zeros(num_embeddings, self._embedding_dim))
910910
self._ema_w.data.normal_()
911911

912912
@property

src/anomalib/models/image/efficient_ad/lightning_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def prepare_pretrained_model(self) -> None:
167167
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
168168
)
169169
logger.info(f"Load pretrained teacher model from {teacher_path}")
170-
self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device)))
170+
self.model.teacher.load_state_dict(
171+
torch.load(teacher_path, map_location=torch.device(self.device), weights_only=True),
172+
)
171173

172174
def prepare_imagenette_data(self, image_size: tuple[int, int] | torch.Size) -> None:
173175
"""Prepare ImageNette dataset transformations.

src/anomalib/utils/path.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,29 @@ def convert_to_snake_case(s: str) -> str:
168168
return re.sub(r"__+", "_", s)
169169

170170

171+
def convert_snake_to_pascal_case(snake_case: str) -> str:
172+
"""Convert snake_case string to PascalCase.
173+
174+
This function takes a string in snake_case format (words separated by underscores)
175+
and converts it to PascalCase format (each word capitalized and concatenated).
176+
177+
Args:
178+
snake_case (str): Input string in snake_case format (e.g. ``"efficient_ad"``)
179+
180+
Returns:
181+
str: Output string in PascalCase format (e.g. ``"EfficientAd"``)
182+
183+
Examples:
184+
>>> convert_snake_to_pascal_case("efficient_ad")
185+
'EfficientAd'
186+
>>> convert_snake_to_pascal_case("patchcore")
187+
'Patchcore'
188+
>>> convert_snake_to_pascal_case("reverse_distillation")
189+
'ReverseDistillation'
190+
"""
191+
return "".join(word.capitalize() for word in snake_case.split("_"))
192+
193+
171194
def convert_to_title_case(text: str) -> str:
172195
"""Convert text to title case, handling various text formats.
173196

tests/integration/model/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
from anomalib.data import AnomalibDataModule, MVTecAD
1818
from anomalib.deploy import ExportType
1919
from anomalib.engine import Engine
20-
from anomalib.models import AnomalibModule, get_available_models, get_model
20+
from anomalib.models import AnomalibModule, get_model, list_models
2121

2222

2323
def models() -> set[str]:
2424
"""Return all available models."""
25-
return get_available_models()
25+
return list_models()
2626

2727

2828
def export_types() -> list[ExportType]:

tests/integration/tools/test_gradio_entrypoint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@ def get_functions() -> tuple[Callable, Callable]:
3838
def test_torch_inference(
3939
get_functions: tuple[Callable, Callable],
4040
ckpt_path: Callable[[str], Path],
41+
monkeypatch: pytest.MonkeyPatch,
4142
) -> None:
4243
"""Test gradio_inference.py."""
44+
# Set TRUST_REMOTE_CODE environment variable for the test
45+
monkeypatch.setenv("TRUST_REMOTE_CODE", "1")
46+
4347
_ckpt_path = ckpt_path("Padim")
4448
parser, inferencer = get_functions
4549
model = Padim.load_from_checkpoint(_ckpt_path)

tests/integration/tools/test_torch_entrypoint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ def test_torch_inference(
3535
project_path: Path,
3636
ckpt_path: Callable[[str], Path],
3737
get_dummy_inference_image: str,
38+
monkeypatch: pytest.MonkeyPatch,
3839
) -> None:
3940
"""Test torch_inference.py."""
41+
# Set TRUST_REMOTE_CODE environment variable for the test
42+
monkeypatch.setenv("TRUST_REMOTE_CODE", "1")
43+
4044
_ckpt_path = ckpt_path("Padim")
4145
get_parser, infer = get_functions
4246
model = Padim.load_from_checkpoint(_ckpt_path)

0 commit comments

Comments
(0)