diff --git a/CHANGELOG.md b/CHANGELOG.md
index fae9694b..60d39093 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added support for materializing dataset for train and test dataframe separately ([#470](https://github.com/pyg-team/pytorch-frame/issues/470))
 - Added support for PyTorch 2.5 ([#464](https://github.com/pyg-team/pytorch-frame/pull/464))
 - Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444))
 - Added `is_floating_point` method to `MultiNestedTensor` and `MultiEmbeddingTensor` ([#445](https://github.com/pyg-team/pytorch-frame/pull/445))
diff --git a/test/data/test_dataset.py b/test/data/test_dataset.py
index 1b23ba5e..ffbcb7db 100644
--- a/test/data/test_dataset.py
+++ b/test/data/test_dataset.py
@@ -280,3 +280,44 @@ def test_col_to_pattern_raise_error():
     dataset = FakeDataset(num_rows=10, stypes=[torch_frame.timestamp])
     Dataset(dataset.df, dataset.col_to_stype, dataset.target_col,
             col_to_time_format=2)
+
+
+def test_materialization_with_col_stats(tmpdir):
+    tmp_path = str(tmpdir.mkdir("image"))
+    text_embedder_cfg = TextEmbedderConfig(
+        text_embedder=HashTextEmbedder(1),
+        batch_size=8,
+    )
+    image_embedder_cfg = ImageEmbedderConfig(
+        image_embedder=RandomImageEmbedder(1),
+        batch_size=8,
+    )
+    dataset_stypes = [
+        torch_frame.categorical,
+        torch_frame.numerical,
+        torch_frame.multicategorical,
+        torch_frame.sequence_numerical,
+        torch_frame.timestamp,
+        torch_frame.text_embedded,
+        torch_frame.embedding,
+        torch_frame.image_embedded,
+    ]
+    train_dataset = FakeDataset(
+        num_rows=10,
+        stypes=dataset_stypes,
+        col_to_text_embedder_cfg=text_embedder_cfg,
+        col_to_image_embedder_cfg=image_embedder_cfg,
+        tmp_path=tmp_path,
+    )
+    train_dataset.materialize()  # materialize to compute col_stats
+    test_dataset = FakeDataset(
+        num_rows=5,
+        stypes=dataset_stypes,
+        col_to_text_embedder_cfg=text_embedder_cfg,
+        col_to_image_embedder_cfg=image_embedder_cfg,
+        tmp_path=tmp_path,
+    )
+    test_dataset.materialize(col_stats=train_dataset.col_stats)
+
+    assert train_dataset.col_stats == test_dataset.col_stats, \
+        "col_stats should be the same for train and test datasets"
diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py
index fd5614bd..ab8b8639 100644
--- a/torch_frame/data/dataset.py
+++ b/torch_frame/data/dataset.py
@@ -554,6 +554,7 @@ def materialize(
         self,
         device: torch.device | None = None,
         path: str | None = None,
+        col_stats: dict[str, dict[StatType, Any]] | None = None,
     ) -> Dataset:
         r"""Materializes the dataset into a tensor representation. From
         this point onwards, the dataset should be treated as read-only.
@@ -570,6 +571,10 @@
                 :obj:`path`. If :obj:`path` is :obj:`None`, this will
                 materialize the dataset without caching.
                 (default: :obj:`None`)
+            col_stats (Dict[str, Dict[StatType, Any]], optional): optional
+                col_stats provided by the user. If not provided, the
+                statistics are computed from the dataframe itself. (default: :obj:`None`)
+
         """
         if self.is_materialized:
             # Materialized without specifying path at first and materialize
@@ -589,23 +594,37 @@
             return self
 
         # 1. Fill column statistics:
-        for col, stype in self.col_to_stype.items():
-            ser = self.df[col]
-            self._col_stats[col] = compute_col_stats(
-                ser,
-                stype,
-                sep=self.col_to_sep.get(col, None),
-                time_format=self.col_to_time_format.get(col, None),
-            )
-            # For a target column, sort categories lexicographically such that
-            # we do not accidentally swap labels in binary classification
-            # tasks.
-            if col == self.target_col and stype == torch_frame.categorical:
-                index, value = self._col_stats[col][StatType.COUNT]
-                if len(index) == 2:
-                    ser = pd.Series(index=index, data=value).sort_index()
-                    index, value = ser.index.tolist(), ser.values.tolist()
-                    self._col_stats[col][StatType.COUNT] = (index, value)
+        if col_stats is None:
+            # calculate from data if col_stats is not provided
+            for col, stype in self.col_to_stype.items():
+                ser = self.df[col]
+                self._col_stats[col] = compute_col_stats(
+                    ser,
+                    stype,
+                    sep=self.col_to_sep.get(col, None),
+                    time_format=self.col_to_time_format.get(col, None),
+                )
+                # For a target column, sort categories lexicographically
+                # such that we do not accidentally swap labels in binary
+                # classification tasks.
+                if col == self.target_col and stype == torch_frame.categorical:
+                    index, value = self._col_stats[col][StatType.COUNT]
+                    if len(index) == 2:
+                        ser = pd.Series(index=index, data=value).sort_index()
+                        index, value = ser.index.tolist(), ser.values.tolist()
+                        self._col_stats[col][StatType.COUNT] = (index, value)
+        else:
+            # basic validation for the col_stats provided by the user
+            for col_, stype_ in self.col_to_stype.items():
+                assert col_ in col_stats, \
+                    f"{col_} is not specified in the provided col_stats"
+                stats_ = col_stats[col_]
+                assert all([key_ in stats_
+                            for key_ in StatType.stats_for_stype(stype_)]), \
+                    "not all required stats are calculated" \
+                    f" in the provided col_stats for {col_}"
+
+            self._col_stats = col_stats
 
         # 2. Create the `TensorFrame`:
         self._to_tensor_frame_converter = self._get_tensorframe_converter()