From 4035d72812784c43ba3216fb45f5e7009272863e Mon Sep 17 00:00:00 2001 From: HoustonJ2013 Date: Fri, 20 Dec 2024 11:40:07 -0600 Subject: [PATCH 01/11] add variable col_stats in dataset.materialize --- torch_frame/data/dataset.py | 42 +++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index fd5614bdf..c85004b89 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -554,6 +554,7 @@ def materialize( self, device: torch.device | None = None, path: str | None = None, + col_stats: dict[str, dict[StatType, Any]] | None = None, ) -> Dataset: r"""Materializes the dataset into a tensor representation. From this point onwards, the dataset should be treated as read-only. @@ -588,24 +589,29 @@ def materialize( self._is_materialized = True return self - # 1. Fill column statistics: - for col, stype in self.col_to_stype.items(): - ser = self.df[col] - self._col_stats[col] = compute_col_stats( - ser, - stype, - sep=self.col_to_sep.get(col, None), - time_format=self.col_to_time_format.get(col, None), - ) - # For a target column, sort categories lexicographically such that - # we do not accidentally swap labels in binary classification - # tasks. - if col == self.target_col and stype == torch_frame.categorical: - index, value = self._col_stats[col][StatType.COUNT] - if len(index) == 2: - ser = pd.Series(index=index, data=value).sort_index() - index, value = ser.index.tolist(), ser.values.tolist() - self._col_stats[col][StatType.COUNT] = (index, value) + # 1. 
Fill column statistics: + if col_stats is None: + # calculate from data if col_stats is not provided + for col, stype in self.col_to_stype.items(): + ser = self.df[col] + self._col_stats[col] = compute_col_stats( + ser, + stype, + sep=self.col_to_sep.get(col, None), + time_format=self.col_to_time_format.get(col, None), + ) + # For a target column, sort categories lexicographically such that + # we do not accidentally swap labels in binary classification + # tasks. + if col == self.target_col and stype == torch_frame.categorical: + index, value = self._col_stats[col][StatType.COUNT] + if len(index) == 2: + ser = pd.Series(index=index, data=value).sort_index() + index, value = ser.index.tolist(), ser.values.tolist() + self._col_stats[col][StatType.COUNT] = (index, value) + else: + self._col_stats = col_stats + # 2. Create the `TensorFrame`: self._to_tensor_frame_converter = self._get_tensorframe_converter() From b97abc8b4252c39f7193d2b33beb29979d95fd89 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Dec 2024 03:47:57 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_frame/data/dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index c85004b89..fe1acc512 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -589,7 +589,7 @@ def materialize( self._is_materialized = True return self - # 1. Fill column statistics: + # 1. Fill column statistics: if col_stats is None: # calculate from data if col_stats is not provided for col, stype in self.col_to_stype.items(): @@ -612,7 +612,6 @@ def materialize( else: self._col_stats = col_stats - # 2. 
Create the `TensorFrame`: self._to_tensor_frame_converter = self._get_tensorframe_converter() self._tensor_frame = self._to_tensor_frame_converter(self.df, device) From c2d4211e06f5dc66ea08df19b1701d9285b3368e Mon Sep 17 00:00:00 2001 From: HoustonJ2013 Date: Sun, 22 Dec 2024 20:31:14 -0600 Subject: [PATCH 03/11] add change log and doc string --- CHANGELOG.md | 2 +- torch_frame/data/dataset.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fae9694b9..770fd3a52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## \[Unreleased\] ### Added - +- Added support for materializing dataset for train and test dataframe separately([#470](https://github.com/pyg-team/pytorch-frame/issues/470)) - Added support for PyTorch 2.5 ([#464](https://github.com/pyg-team/pytorch-frame/pull/464)) - Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444)) - Added `is_floating_point` method to `MultiNestedTensor` and `MultiEmbeddingTensor` ([#445](https://github.com/pyg-team/pytorch-frame/pull/445)) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index c85004b89..8f2fb4d1e 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -571,6 +571,10 @@ def materialize( :obj:`path`. If :obj:`path` is :obj:`None`, this will materialize the dataset without caching. (default: :obj:`None`) + col_stats (Dict[str, Dict[StatType, Any]], optional): optional + col_stats provided by the user. If not provided, the statistics + is calculated from the dataframe itself. 
(default: :obj:`None`) + """ if self.is_materialized: # Materialized without specifying path at first and materialize @@ -600,9 +604,9 @@ def materialize( sep=self.col_to_sep.get(col, None), time_format=self.col_to_time_format.get(col, None), ) - # For a target column, sort categories lexicographically such that - # we do not accidentally swap labels in binary classification - # tasks. + # For a target column, sort categories lexicographically such that + # we do not accidentally swap labels in binary classification + # tasks. if col == self.target_col and stype == torch_frame.categorical: index, value = self._col_stats[col][StatType.COUNT] if len(index) == 2: From ac95d3f464d2e19389f29de09f12423546a514d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 02:34:15 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 1 + torch_frame/data/dataset.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 770fd3a52..60d390937 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## \[Unreleased\] ### Added + - Added support for materializing dataset for train and test dataframe separately([#470](https://github.com/pyg-team/pytorch-frame/issues/470)) - Added support for PyTorch 2.5 ([#464](https://github.com/pyg-team/pytorch-frame/pull/464)) - Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444)) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 8b84de0b6..ddf310516 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -571,7 +571,7 @@ def materialize( :obj:`path`. 
If :obj:`path` is :obj:`None`, this will materialize the dataset without caching. (default: :obj:`None`) - col_stats (Dict[str, Dict[StatType, Any]], optional): optional + col_stats (Dict[str, Dict[StatType, Any]], optional): optional col_stats provided by the user. If not provided, the statistics is calculated from the dataframe itself. (default: :obj:`None`) @@ -604,9 +604,9 @@ def materialize( sep=self.col_to_sep.get(col, None), time_format=self.col_to_time_format.get(col, None), ) - # For a target column, sort categories lexicographically such that - # we do not accidentally swap labels in binary classification - # tasks. + # For a target column, sort categories lexicographically such that + # we do not accidentally swap labels in binary classification + # tasks. if col == self.target_col and stype == torch_frame.categorical: index, value = self._col_stats[col][StatType.COUNT] if len(index) == 2: From 75bb624a4b99f83a2b77c991f38f120452a423cb Mon Sep 17 00:00:00 2001 From: HoustonJ2013 Date: Sun, 22 Dec 2024 20:44:41 -0600 Subject: [PATCH 05/11] pep8 e501 --- torch_frame/data/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 8b84de0b6..f239074ba 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -595,7 +595,7 @@ def materialize( # 1. Fill column statistics: if col_stats is None: - # calculate from data if col_stats is not provided + # calculate from data if col_stats is not provided for col, stype in self.col_to_stype.items(): ser = self.df[col] self._col_stats[col] = compute_col_stats( @@ -604,9 +604,9 @@ def materialize( sep=self.col_to_sep.get(col, None), time_format=self.col_to_time_format.get(col, None), ) - # For a target column, sort categories lexicographically such that - # we do not accidentally swap labels in binary classification - # tasks. 
+ # For a target column, sort categories lexicographically + # such that we do not accidentally swap labels in binary + # classification tasks. if col == self.target_col and stype == torch_frame.categorical: index, value = self._col_stats[col][StatType.COUNT] if len(index) == 2: From c7824f66e22efd0185a448f54f03cfbe1d880c0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 02:48:29 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_frame/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 864d172d9..b916f09a5 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -595,7 +595,7 @@ def materialize( # 1. Fill column statistics: if col_stats is None: - # calculate from data if col_stats is not provided + # calculate from data if col_stats is not provided for col, stype in self.col_to_stype.items(): ser = self.df[col] self._col_stats[col] = compute_col_stats( @@ -604,8 +604,8 @@ def materialize( sep=self.col_to_sep.get(col, None), time_format=self.col_to_time_format.get(col, None), ) - # For a target column, sort categories lexicographically - # such that we do not accidentally swap labels in binary + # For a target column, sort categories lexicographically + # such that we do not accidentally swap labels in binary # classification tasks. 
if col == self.target_col and stype == torch_frame.categorical: index, value = self._col_stats[col][StatType.COUNT] From 89d4f9d0c37419c5dbf31dcb619168fed334b7b6 Mon Sep 17 00:00:00 2001 From: HoustonJ2013 Date: Thu, 26 Dec 2024 19:15:52 -0600 Subject: [PATCH 07/11] add validation and unittest for dataset materialize col_stats --- test/data/test_dataset.py | 41 +++++++++++++++++++++++++++++++++++++ torch_frame/data/dataset.py | 10 +++++++++ 2 files changed, 51 insertions(+) diff --git a/test/data/test_dataset.py b/test/data/test_dataset.py index 1b23ba5e9..2d0ac1e6f 100644 --- a/test/data/test_dataset.py +++ b/test/data/test_dataset.py @@ -280,3 +280,44 @@ def test_col_to_pattern_raise_error(): dataset = FakeDataset(num_rows=10, stypes=[torch_frame.timestamp]) Dataset(dataset.df, dataset.col_to_stype, dataset.target_col, col_to_time_format=2) + + +def test_materialization_with_col_stats(tmpdir): + tmp_path = str(tmpdir.mkdir("image")) + text_embedder_cfg = TextEmbedderConfig( + text_embedder=HashTextEmbedder(1), + batch_size=8, + ) + image_embedder_cfg = ImageEmbedderConfig( + image_embedder=RandomImageEmbedder(1), + batch_size=8, + ) + dataset_stypes = [ + torch_frame.categorical, + torch_frame.numerical, + torch_frame.multicategorical, + torch_frame.sequence_numerical, + torch_frame.timestamp, + torch_frame.text_embedded, + torch_frame.embedding, + torch_frame.image_embedded, + ] + train_dataset = FakeDataset( + num_rows=10, + stypes=dataset_stypes, + col_to_text_embedder_cfg=text_embedder_cfg, + col_to_image_embedder_cfg=image_embedder_cfg, + tmp_path=tmp_path, + ) + train_dataset.materialize() # materialize to compute col_stats + test_dataset = FakeDataset( + num_rows=5, + stypes=dataset_stypes, + col_to_text_embedder_cfg=text_embedder_cfg, + col_to_image_embedder_cfg=image_embedder_cfg, + tmp_path=tmp_path, + ) + test_dataset.materialize(col_stats=train_dataset.col_stats) + + assert train_dataset.col_stats == test_dataset.col_stats, \ + "col_stats should be the 
same for train and test datasets" \ No newline at end of file diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 864d172d9..ed289793c 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -614,6 +614,16 @@ def materialize( index, value = ser.index.tolist(), ser.values.tolist() self._col_stats[col][StatType.COUNT] = (index, value) else: + # basic validation for the for col_stats provided by the user + for col_, stype_ in self.col_to_stype.items(): + assert col_ in col_stats, \ + f"{col_} is not specified in the provided col_stats" + stats_ = col_stats[col_] + assert all([key_ in stats_ + for key_ in StatType.stats_for_stype(stype_)]), \ + f"not all required stats are calculated" \ + " in the provided col_stats for {col}" + self._col_stats = col_stats # 2. Create the `TensorFrame`: From b089036925dcd0cdf9e72210e57c049ce25ec6ce Mon Sep 17 00:00:00 2001 From: HoustonJ2013 Date: Thu, 26 Dec 2024 19:17:22 -0600 Subject: [PATCH 08/11] add validation and unittest for dataset materialize col_stats --- torch_frame/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 80ce5fa56..42f50d3f9 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -614,7 +614,7 @@ def materialize( index, value = ser.index.tolist(), ser.values.tolist() self._col_stats[col][StatType.COUNT] = (index, value) else: - # basic validation for the for col_stats provided by the user + # basic validation for the col_stats provided by the user for col_, stype_ in self.col_to_stype.items(): assert col_ in col_stats, \ f"{col_} is not specified in the provided col_stats" From 7ecd3f0f78c90c0ea1302433b735ce12df365ef7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Dec 2024 01:17:55 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- test/data/test_dataset.py | 4 ++-- torch_frame/data/dataset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/data/test_dataset.py b/test/data/test_dataset.py index 2d0ac1e6f..ffbcb7db8 100644 --- a/test/data/test_dataset.py +++ b/test/data/test_dataset.py @@ -309,7 +309,7 @@ def test_materialization_with_col_stats(tmpdir): col_to_image_embedder_cfg=image_embedder_cfg, tmp_path=tmp_path, ) - train_dataset.materialize() # materialize to compute col_stats + train_dataset.materialize() # materialize to compute col_stats test_dataset = FakeDataset( num_rows=5, stypes=dataset_stypes, @@ -320,4 +320,4 @@ def test_materialization_with_col_stats(tmpdir): test_dataset.materialize(col_stats=train_dataset.col_stats) assert train_dataset.col_stats == test_dataset.col_stats, \ - "col_stats should be the same for train and test datasets" \ No newline at end of file + "col_stats should be the same for train and test datasets" diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 42f50d3f9..54cf8df8f 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -619,7 +619,7 @@ def materialize( assert col_ in col_stats, \ f"{col_} is not specified in the provided col_stats" stats_ = col_stats[col_] - assert all([key_ in stats_ + assert all([key_ in stats_ for key_ in StatType.stats_for_stype(stype_)]), \ f"not all required stats are calculated" \ " in the provided col_stats for {col}" From 816faee6d3b2699bcea52e4514f1d29bb06efc83 Mon Sep 17 00:00:00 2001 From: HoustonJ2013 Date: Thu, 26 Dec 2024 19:21:56 -0600 Subject: [PATCH 10/11] add validation and unittest for dataset materialize col_stats --- torch_frame/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 42f50d3f9..81f64d902 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -621,8 +621,8 @@ def 
materialize( stats_ = col_stats[col_] assert all([key_ in stats_ for key_ in StatType.stats_for_stype(stype_)]), \ - f"not all required stats are calculated" \ - " in the provided col_stats for {col}" + "not all required stats are calculated" \ + f" in the provided col_stats for {col}" self._col_stats = col_stats From 3137a6c3c435c62939105504f6f3d4950c1f54f5 Mon Sep 17 00:00:00 2001 From: HoustonJ2013 Date: Thu, 26 Dec 2024 19:41:13 -0600 Subject: [PATCH 11/11] fix pep8 format --- torch_frame/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 115b63ee8..ab8b86393 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -621,7 +621,7 @@ def materialize( stats_ = col_stats[col_] assert all([key_ in stats_ for key_ in StatType.stats_for_stype(stype_)]), \ - "not all required stats are calculated" \ + "not all required stats are calculated" \ f" in the provided col_stats for {col}" self._col_stats = col_stats