From e05c2dddd8e47fcf5630f686ed51f901ecb245d1 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 9 Sep 2024 10:17:06 +0000 Subject: [PATCH 1/5] update --- test/utils/test_io.py | 18 +++++++++++++++++- torch_frame/__init__.py | 7 ++++++- torch_frame/utils/io.py | 30 ++++++++++++++++++++++++++---- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/test/utils/test_io.py b/test/utils/test_io.py index cf8d29586..7600a8b2e 100644 --- a/test/utils/test_io.py +++ b/test/utils/test_io.py @@ -3,8 +3,10 @@ import shutil import tempfile +import pytest + import torch_frame -from torch_frame import load, save +from torch_frame import TensorFrame, load, save from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.config.text_tokenizer import TextTokenizerConfig from torch_frame.datasets import FakeDataset @@ -114,3 +116,17 @@ def test_save_load_tensor_frame(): tf, col_stats = load(path) assert dataset.col_stats == col_stats assert dataset.tensor_frame == tf + + +class UntrustedClass: + pass + + +def test_load_weights_only_gracefully(tmpdir): + save( + tensor_frame=TensorFrame({}, {}), + col_stats={'a': UntrustedClass()}, + path=tmpdir.join('tf.pt'), + ) + with pytest.warns(UserWarning, match='Weights only load failed'): + load(tmpdir.join('tf.pt')) diff --git a/torch_frame/__init__.py b/torch_frame/__init__.py index 7161acc38..aa3dce170 100644 --- a/torch_frame/__init__.py +++ b/torch_frame/__init__.py @@ -27,8 +27,13 @@ if WITH_PT24: import torch - + import numpy as np + import codecs torch.serialization.add_safe_globals([ + np._core.multiarray.scalar, + np.dtype, + np.dtypes.Int32DType, + codecs.encode, stype, torch_frame.data.stats.StatType, ]) diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 50732d4a6..8fd51c6b1 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -1,5 +1,8 @@ from __future__ import annotations +import pickle +import re +import warnings from typing import Any import torch @@ -13,7 +16,7 @@ ) from torch_frame.data.multi_tensor import _MultiTensor from torch_frame.data.stats import StatType -from torch_frame.typing import WITH_PT24, TensorData +from torch_frame.typing import TensorData def serialize_feat_dict( @@ -96,9 +99,28 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24) + if torch_frame.typing.WITH_PT24: + try: + tf_dict, col_stats = torch.load(path, weights_only=True) + except pickle.UnpicklingError as e: + error_msg = str(e) + if "add_safe_globals" in error_msg: + warn_msg = ("Weights only load failed. Please file an issue " + "to make `torch.load(weights_only=True)` " + "compatible in your case.") + match = re.search(r'add_safe_globals\(.*?\)', error_msg) + if match is not None: + warnings.warn(f"{warn_msg} Please use " + f"`torch.serialization.{match.group()}` to " + f"allowlist this global.") + else: + warnings.warn(warn_msg) + + tf_dict, col_stats = torch.load(path, weights_only=False) + else: + raise e + tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) - tensor_frame = TensorFrame(**tf_dict) - tensor_frame.to(device) + tensor_frame = TensorFrame(**tf_dict).to(device) return tensor_frame, col_stats From 8395116f6d9b4563a4d02e4e86dbb32146289010 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 9 Sep 2024 10:19:59 +0000 Subject: [PATCH 2/5] update --- test/utils/test_io.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/utils/test_io.py b/test/utils/test_io.py index 7600a8b2e..6623e0c4b 100644 --- a/test/utils/test_io.py +++ b/test/utils/test_io.py @@ -122,6 +122,10 @@ class UntrustedClass: pass +@pytest.mark.skipif( + not torch_frame.typing.WITH_PT24, + reason='Requres PyTorch 2.4', +) def test_load_weights_only_gracefully(tmpdir): save( tensor_frame=TensorFrame({}, {}), From 9bb05856f7c14beca09dba2790338f71b85b3bd6 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 9 Sep 2024 10:20:33 +0000 Subject: [PATCH 3/5] . --- torch_frame/utils/io.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 8fd51c6b1..f10c8f475 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -119,6 +119,8 @@ def load( tf_dict, col_stats = torch.load(path, weights_only=False) else: raise e + else: + tf_dict, col_stats = torch.load(path, weights_only=False) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) From 4e4c3ea5cccd26328952a1bfa8dfb798064bed95 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 9 Sep 2024 10:23:56 +0000 Subject: [PATCH 4/5] . --- torch_frame/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torch_frame/__init__.py b/torch_frame/__init__.py index aa3dce170..182e0dbfe 100644 --- a/torch_frame/__init__.py +++ b/torch_frame/__init__.py @@ -27,13 +27,7 @@ if WITH_PT24: import torch - import numpy as np - import codecs torch.serialization.add_safe_globals([ - np._core.multiarray.scalar, - np.dtype, - np.dtypes.Int32DType, - codecs.encode, stype, torch_frame.data.stats.StatType, ]) From 7a0226bb2032eade080276ff815a04f1bde21543 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 9 Sep 2024 10:27:16 +0000 Subject: [PATCH 5/5] . --- torch_frame/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_frame/__init__.py b/torch_frame/__init__.py index 182e0dbfe..7161acc38 100644 --- a/torch_frame/__init__.py +++ b/torch_frame/__init__.py @@ -27,6 +27,7 @@ if WITH_PT24: import torch + torch.serialization.add_safe_globals([ stype, torch_frame.data.stats.StatType,