From fbcba9d900f1a8425e073b905ba3862740625fa0 Mon Sep 17 00:00:00 2001 From: NeelKondapalli Date: Wed, 31 Jul 2024 18:45:51 +0530 Subject: [PATCH 1/6] Issue 422: Fixed warning and added safe globals --- torch_frame/utils/io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 5a22f3e8e..950419d23 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -15,6 +15,8 @@ from torch_frame.data.stats import StatType from torch_frame.typing import TensorData +torch.serialization.add_safe_globals([torch_frame.stype]) +torch.serialization.add_safe_globals([StatType]) def serialize_feat_dict( feat_dict: dict[torch_frame.stype, TensorData] @@ -95,7 +97,7 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path) + tf_dict, col_stats = torch.load(path, weights_only = True) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict) From 9ef2da0f18409604ef2fd82caea1c14d420a57c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 Jul 2024 13:21:04 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_frame/utils/io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 950419d23..b2c47828f 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -18,6 +18,7 @@ torch.serialization.add_safe_globals([torch_frame.stype]) torch.serialization.add_safe_globals([StatType]) + def serialize_feat_dict( feat_dict: dict[torch_frame.stype, TensorData] ) -> dict[torch_frame.stype, Any]: @@ -97,7 +98,7 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path, weights_only = True) + tf_dict, col_stats = torch.load(path, weights_only=True) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict) From 4ce8e5c21b21137d96b0ebbe614464704e7cfb49 Mon Sep 17 00:00:00 2001 From: NeelKondapalli Date: Thu, 1 Aug 2024 13:02:04 +0530 Subject: [PATCH 3/6] Removed serialization.add_safe_globals logic not available in torch 2.2.0 --- torch_frame/utils/io.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index b2c47828f..5fa1a477a 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -15,9 +15,6 @@ from torch_frame.data.stats import StatType from torch_frame.typing import TensorData -torch.serialization.add_safe_globals([torch_frame.stype]) -torch.serialization.add_safe_globals([StatType]) - def serialize_feat_dict( feat_dict: dict[torch_frame.stype, TensorData] @@ -98,7 +95,7 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path, weights_only=True) + tf_dict, col_stats = torch.load(path, weights_only=False) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict) From f072bcfa4c40dc919e2fb5b930744a1521fdceb3 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 24 Aug 2024 15:33:45 +0000 Subject: [PATCH 4/6] update --- torch_frame/__init__.py | 16 +++++++++++++++- torch_frame/typing.py | 4 ++++ torch_frame/utils/io.py | 5 +++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/torch_frame/__init__.py b/torch_frame/__init__.py index 7aa3f9904..7161acc38 100644 --- a/torch_frame/__init__.py +++ b/torch_frame/__init__.py @@ -12,13 +12,27 @@ embedding, ) from .data import TensorFrame -from .typing import TaskType, Metric, DataFrame, NAStrategy +from .typing import ( + TaskType, + Metric, + DataFrame, + NAStrategy, + WITH_PT24, +) from torch_frame.utils import save, load, cat # noqa import torch_frame.data # noqa import torch_frame.datasets # noqa import torch_frame.nn # noqa import torch_frame.gbdt # noqa +if WITH_PT24: + import torch + + torch.serialization.add_safe_globals([ + stype, + torch_frame.data.stats.StatType, + ]) + __version__ = '0.2.3' __all__ = [ diff --git a/torch_frame/typing.py b/torch_frame/typing.py index a2e49159d..c7aede63a 100644 --- a/torch_frame/typing.py +++ b/torch_frame/typing.py @@ -4,11 +4,15 @@ from typing import Dict, List, Mapping, Union import pandas as pd +import torch from torch import Tensor from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor from torch_frame.data.multi_nested_tensor import MultiNestedTensor +WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2 +WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4 + class Metric(Enum): r"""The metric. diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 5fa1a477a..2f98a1e80 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -80,7 +80,8 @@ def save(tensor_frame: TensorFrame, def load( - path: str, device: torch.device | None = None + path: str, + device: torch.device | None = None, ) -> tuple[TensorFrame, dict[str, dict[StatType, Any]] | None]: r"""Load saved :class:`TensorFrame` object and optional :obj:`col_stats` from a specified path. @@ -95,7 +96,7 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path, weights_only=False) + tf_dict, col_stats = torch.load(path, weights_only=True) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict) From 54dab53fda29b67116f0214a74279c51e9c1e2b5 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 24 Aug 2024 15:53:03 +0000 Subject: [PATCH 5/6] update --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bbd92324..707ba8d74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Set `weights_only=True` in `torch_frame.load` from PyTorch 2.4 ([#423](https://github.com/pyg-team/pytorch-frame/pull/423)) + ### Deprecated ### Removed From febaeae7bc1ec075daf31853535e5c0a11c33cc5 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 24 Aug 2024 15:54:24 +0000 Subject: [PATCH 6/6] update --- torch_frame/utils/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 2f98a1e80..50732d4a6 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -13,7 +13,7 @@ ) from torch_frame.data.multi_tensor import _MultiTensor from torch_frame.data.stats import StatType -from torch_frame.typing import TensorData +from torch_frame.typing import WITH_PT24, TensorData def serialize_feat_dict( @@ -96,7 +96,7 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path, weights_only=True) + tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict)