diff --git a/test/utils/test_infer_stype.py b/test/utils/test_infer_stype.py index 17f6a58c1..158238370 100644 --- a/test/utils/test_infer_stype.py +++ b/test/utils/test_infer_stype.py @@ -1,3 +1,4 @@ +import pandas as pd import pytest import torch_frame @@ -38,3 +39,16 @@ def test_infer_df_stype(with_nan): dataset = get_fake_dataset(num_rows, col_to_text_embedder_cfg, with_nan) col_to_stype_inferred = infer_df_stype(dataset.df) assert col_to_stype_inferred == dataset.col_to_stype + + +def test_infer_multicategorical_stype(): + # Test when multicategoricals are lists + df = pd.DataFrame({ + 'category': [['Books', 'Mystery, Thriller'], + ['Books', "Children's Books", 'Geography'], + ['Books', 'Health', 'Fitness & Dieting'], + ['Books', 'Teen & oung Adult']] * 50, + 'id': [i for i in range(200)] + }) + col_to_stype_inferred = infer_df_stype(df) + assert col_to_stype_inferred['category'] == torch_frame.multicategorical diff --git a/torch_frame/utils/infer_stype.py b/torch_frame/utils/infer_stype.py index cc31402b9..ea202d463 100644 --- a/torch_frame/utils/infer_stype.py +++ b/torch_frame/utils/infer_stype.py @@ -5,6 +5,7 @@ import warnings from typing import Any +import numpy as np import pandas as pd import pandas.api.types as ptypes from dateutil.parser import ParserError @@ -137,19 +138,26 @@ def infer_series_stype(ser: Series) -> stype | None: return stype.embedding # Try different possible seps and mick the largest min_count. - min_count_list = [] - for sep in POSSIBLE_SEPS: - try: - min_count_list.append( - _min_count( - ser.apply(lambda row: MultiCategoricalTensorMapper. - split_by_sep(row, sep)).explode())) - except Exception as e: - logging.warn( - "Mapping series into multicategorical stype " - f"with separator {sep} raised an exception {e}") - continue - if max(min_count_list) > cat_min_count_thresh: + if isinstance(ser.iloc[0], list) or isinstance( + ser.iloc[0], np.ndarray): + max_min_count = _min_count(ser.explode()) + else: + min_count_list = [] + for sep in POSSIBLE_SEPS: + try: + min_count_list.append( + _min_count( + ser.apply( + lambda row: MultiCategoricalTensorMapper. + split_by_sep(row, sep)).explode())) + except Exception as e: + logging.warn( + "Mapping series into multicategorical stype " + f"with separator {sep} raised an exception {e}") + continue + max_min_count = max(min_count_list or [0]) + + if max_min_count > cat_min_count_thresh: return stype.multicategorical else: return stype.text_embedded