diff --git a/test/utils/test_io.py b/test/utils/test_io.py index cf8d29586..6623e0c4b 100644 --- a/test/utils/test_io.py +++ b/test/utils/test_io.py @@ -3,8 +3,10 @@ import shutil import tempfile +import pytest + import torch_frame -from torch_frame import load, save +from torch_frame import TensorFrame, load, save from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.config.text_tokenizer import TextTokenizerConfig from torch_frame.datasets import FakeDataset @@ -114,3 +116,21 @@ def test_save_load_tensor_frame(): tf, col_stats = load(path) assert dataset.col_stats == col_stats assert dataset.tensor_frame == tf + + +class UntrustedClass: + pass + + +@pytest.mark.skipif( + not torch_frame.typing.WITH_PT24, + reason='Requres PyTorch 2.4', +) +def test_load_weights_only_gracefully(tmpdir): + save( + tensor_frame=TensorFrame({}, {}), + col_stats={'a': UntrustedClass()}, + path=tmpdir.join('tf.pt'), + ) + with pytest.warns(UserWarning, match='Weights only load failed'): + load(tmpdir.join('tf.pt')) diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 50732d4a6..f10c8f475 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -1,5 +1,8 @@ from __future__ import annotations +import pickle +import re +import warnings from typing import Any import torch @@ -13,7 +16,7 @@ ) from torch_frame.data.multi_tensor import _MultiTensor from torch_frame.data.stats import StatType -from torch_frame.typing import WITH_PT24, TensorData +from torch_frame.typing import TensorData def serialize_feat_dict( @@ -96,9 +99,30 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24) + if torch_frame.typing.WITH_PT24: + try: + tf_dict, col_stats = torch.load(path, weights_only=True) + except pickle.UnpicklingError as e: + error_msg = str(e) + if "add_safe_globals" in error_msg: + warn_msg = ("Weights only load failed. Please file an issue " + "to make `torch.load(weights_only=True)` " + "compatible in your case.") + match = re.search(r'add_safe_globals\(.*?\)', error_msg) + if match is not None: + warnings.warn(f"{warn_msg} Please use " + f"`torch.serialization.{match.group()}` to " + f"allowlist this global.") + else: + warnings.warn(warn_msg) + + tf_dict, col_stats = torch.load(path, weights_only=False) + else: + raise e + else: + tf_dict, col_stats = torch.load(path, weights_only=False) + tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) - tensor_frame = TensorFrame(**tf_dict) - tensor_frame.to(device) + tensor_frame = TensorFrame(**tf_dict).to(device) return tensor_frame, col_stats