From 1e6968a0e7f51324b5303ed0c7af2fcc208a7610 Mon Sep 17 00:00:00 2001 From: yiweny Date: Thu, 18 Jul 2024 23:52:19 +0000 Subject: [PATCH 1/4] fix infer multicategorical stype function and add test case --- test/utils/test_infer_stype.py | 14 ++++++++++++++ torch_frame/utils/infer_stype.py | 32 ++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/test/utils/test_infer_stype.py b/test/utils/test_infer_stype.py index 17f6a58c1..158238370 100644 --- a/test/utils/test_infer_stype.py +++ b/test/utils/test_infer_stype.py @@ -1,3 +1,4 @@ +import pandas as pd import pytest import torch_frame @@ -38,3 +39,16 @@ def test_infer_df_stype(with_nan): dataset = get_fake_dataset(num_rows, col_to_text_embedder_cfg, with_nan) col_to_stype_inferred = infer_df_stype(dataset.df) assert col_to_stype_inferred == dataset.col_to_stype + + +def test_infer_multicategorical_stype(): + # Test when multicategoricals are lists + df = pd.DataFrame({ + 'category': [['Books', 'Mystery, Thriller'], + ['Books', "Children's Books", 'Geography'], + ['Books', 'Health', 'Fitness & Dieting'], + ['Books', 'Teen & oung Adult']] * 50, + 'id': [i for i in range(200)] + }) + col_to_stype_inferred = infer_df_stype(df) + assert col_to_stype_inferred['category'] == torch_frame.multicategorical diff --git a/torch_frame/utils/infer_stype.py b/torch_frame/utils/infer_stype.py index cc31402b9..697ed6167 100644 --- a/torch_frame/utils/infer_stype.py +++ b/torch_frame/utils/infer_stype.py @@ -5,6 +5,7 @@ import warnings from typing import Any +import numpy as np import pandas as pd import pandas.api.types as ptypes from dateutil.parser import ParserError @@ -138,18 +139,25 @@ def infer_series_stype(ser: Series) -> stype | None: # Try different possible seps and mick the largest min_count. min_count_list = [] - for sep in POSSIBLE_SEPS: - try: - min_count_list.append( - _min_count( - ser.apply(lambda row: MultiCategoricalTensorMapper. - split_by_sep(row, sep)).explode())) - except Exception as e: - logging.warn( - "Mapping series into multicategorical stype " - f"with separator {sep} raised an exception {e}") - continue - if max(min_count_list) > cat_min_count_thresh: + if isinstance(ser.iloc[0], list) or isinstance( + ser.iloc[0], np.ndarray): + min_count_list.append(_min_count(ser.explode())) + else: + for sep in POSSIBLE_SEPS: + try: + min_count_list.append( + _min_count( + ser.apply( + lambda row: MultiCategoricalTensorMapper. + split_by_sep(row, sep)).explode())) + except Exception as e: + logging.warn( + "Mapping series into multicategorical stype " + f"with separator {sep} raised an exception {e}") + continue + + if len(min_count_list) > 0 and max( + min_count_list) > cat_min_count_thresh: return stype.multicategorical else: return stype.text_embedded From 0cb98582e386fd3169afc57c0157f356b0a92568 Mon Sep 17 00:00:00 2001 From: yiweny Date: Fri, 19 Jul 2024 05:45:28 +0000 Subject: [PATCH 2/4] finish --- torch_frame/utils/infer_stype.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_frame/utils/infer_stype.py b/torch_frame/utils/infer_stype.py index 697ed6167..0b7385b9c 100644 --- a/torch_frame/utils/infer_stype.py +++ b/torch_frame/utils/infer_stype.py @@ -70,7 +70,8 @@ def infer_series_stype(ser: Series) -> stype | None: # Categorical minimum counting threshold. If the count of the most minor # categories is larger than this value, we treat the column as categorical. cat_min_count_thresh = 4 - + import pdb + pdb.set_trace() if isinstance(ser.iloc[0], list): # Candidates: embedding, sequence_numerical, multicategorical From c3cc5a4f861e0b9d44c5445cafb45592ecd451f3 Mon Sep 17 00:00:00 2001 From: yiweny Date: Fri, 19 Jul 2024 05:48:25 +0000 Subject: [PATCH 3/4] remove pdb --- torch_frame/utils/infer_stype.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_frame/utils/infer_stype.py b/torch_frame/utils/infer_stype.py index 0b7385b9c..697ed6167 100644 --- a/torch_frame/utils/infer_stype.py +++ b/torch_frame/utils/infer_stype.py @@ -70,8 +70,7 @@ def infer_series_stype(ser: Series) -> stype | None: # Categorical minimum counting threshold. If the count of the most minor # categories is larger than this value, we treat the column as categorical. cat_min_count_thresh = 4 - import pdb - pdb.set_trace() + if isinstance(ser.iloc[0], list): # Candidates: embedding, sequence_numerical, multicategorical From eb535f2e82a6faad61a648317202985c6f44bdf4 Mon Sep 17 00:00:00 2001 From: yiweny Date: Sat, 20 Jul 2024 21:48:44 +0000 Subject: [PATCH 4/4] fix code based on review comments --- torch_frame/utils/infer_stype.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_frame/utils/infer_stype.py b/torch_frame/utils/infer_stype.py index 697ed6167..ea202d463 100644 --- a/torch_frame/utils/infer_stype.py +++ b/torch_frame/utils/infer_stype.py @@ -138,11 +138,11 @@ def infer_series_stype(ser: Series) -> stype | None: return stype.embedding # Try different possible seps and mick the largest min_count. - min_count_list = [] if isinstance(ser.iloc[0], list) or isinstance( ser.iloc[0], np.ndarray): - min_count_list.append(_min_count(ser.explode())) + max_min_count = _min_count(ser.explode()) else: + min_count_list = [] for sep in POSSIBLE_SEPS: try: min_count_list.append( @@ -155,9 +155,9 @@ def infer_series_stype(ser: Series) -> stype | None: "Mapping series into multicategorical stype " f"with separator {sep} raised an exception {e}") continue + max_min_count = max(min_count_list or [0]) - if len(min_count_list) > 0 and max( - min_count_list) > cat_min_count_thresh: + if max_min_count > cat_min_count_thresh: return stype.multicategorical else: return stype.text_embedded