From bf8288f1bd56cf60786ea7819c6859ac8ac72a83 Mon Sep 17 00:00:00 2001 From: yiweny Date: Tue, 30 Jul 2024 07:00:06 +0000 Subject: [PATCH 1/3] correctly infer boolean stypes --- torch_frame/utils/infer_stype.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_frame/utils/infer_stype.py b/torch_frame/utils/infer_stype.py index ea202d463..7f93a0a82 100644 --- a/torch_frame/utils/infer_stype.py +++ b/torch_frame/utils/infer_stype.py @@ -109,6 +109,9 @@ def infer_series_stype(ser: Series) -> stype | None: # text_(embedded/tokenized) if ptypes.is_numeric_dtype(ser): + + if ptypes.is_bool_dtype(ser): + return stype.categorical # Candidates: numerical, categorical if ptypes.is_float_dtype(ser) and not (has_nan and (ser % 1 == 0).all()): From 4290aee7d1c9c93a0f18909bf7933091c779bc52 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Tue, 30 Jul 2024 08:26:38 +0000 Subject: [PATCH 2/3] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d17debb4..e0c56a9e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421)) + ### Changed ### Deprecated From 1a4340f4507821fe37bfb590ab8017c96f075e2f Mon Sep 17 00:00:00 2001 From: yiweny Date: Tue, 30 Jul 2024 23:49:38 +0000 Subject: [PATCH 3/3] add a test case --- test/utils/test_infer_stype.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/utils/test_infer_stype.py b/test/utils/test_infer_stype.py index 158238370..c8da8077d 100644 --- a/test/utils/test_infer_stype.py +++ b/test/utils/test_infer_stype.py @@ -41,7 +41,7 @@ def test_infer_df_stype(with_nan): assert col_to_stype_inferred == dataset.col_to_stype -def test_infer_multicategorical_stype(): +def test_infer_stypes(): # Test when multicategoricals are lists df = pd.DataFrame({ 'category': [['Books', 'Mystery, Thriller'], @@ -52,3 +52,8 @@ def test_infer_multicategorical_stype(): }) col_to_stype_inferred = infer_df_stype(df) assert col_to_stype_inferred['category'] == torch_frame.multicategorical + + df = pd.DataFrame({'bool': [True] * 50 + [False] * 50}) + + col_to_stype_inferred = infer_df_stype(df) + assert col_to_stype_inferred['bool'] == torch_frame.categorical