diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bbd92324..707ba8d74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Set `weights_only=True` in `torch_frame.load` from PyTorch 2.4 ([#423](https://github.com/pyg-team/pytorch-frame/pull/423)) + ### Deprecated ### Removed diff --git a/torch_frame/__init__.py b/torch_frame/__init__.py index 7aa3f9904..7161acc38 100644 --- a/torch_frame/__init__.py +++ b/torch_frame/__init__.py @@ -12,13 +12,27 @@ embedding, ) from .data import TensorFrame -from .typing import TaskType, Metric, DataFrame, NAStrategy +from .typing import ( + TaskType, + Metric, + DataFrame, + NAStrategy, + WITH_PT24, +) from torch_frame.utils import save, load, cat # noqa import torch_frame.data # noqa import torch_frame.datasets # noqa import torch_frame.nn # noqa import torch_frame.gbdt # noqa +if WITH_PT24: + import torch + + torch.serialization.add_safe_globals([ + stype, + torch_frame.data.stats.StatType, + ]) + __version__ = '0.2.3' __all__ = [ diff --git a/torch_frame/typing.py b/torch_frame/typing.py index a2e49159d..c7aede63a 100644 --- a/torch_frame/typing.py +++ b/torch_frame/typing.py @@ -4,11 +4,15 @@ from typing import Dict, List, Mapping, Union import pandas as pd +import torch from torch import Tensor from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor from torch_frame.data.multi_nested_tensor import MultiNestedTensor +WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2 +WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4 + class Metric(Enum): r"""The metric. diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 5a22f3e8e..50732d4a6 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -13,7 +13,7 @@ ) from torch_frame.data.multi_tensor import _MultiTensor from torch_frame.data.stats import StatType -from torch_frame.typing import TensorData +from torch_frame.typing import WITH_PT24, TensorData def serialize_feat_dict( @@ -80,7 +80,8 @@ def save(tensor_frame: TensorFrame, def load( - path: str, device: torch.device | None = None + path: str, + device: torch.device | None = None, ) -> tuple[TensorFrame, dict[str, dict[StatType, Any]] | None]: r"""Load saved :class:`TensorFrame` object and optional :obj:`col_stats` from a specified path. @@ -95,7 +96,7 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path) + tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict)