Skip to content

Commit 79856bc

Browse files
authored
feat: add type schema (#1274)
* feat: allows user to define variable types
1 parent b9ada64 commit 79856bc

File tree

5 files changed

+65
-5
lines changed

5 files changed

+65
-5
lines changed

src/ydata_profiling/compare_reports.py

+4
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ def _compare_profile_report_preprocess(
134134
config.html.style.primary_colors
135135
)
136136

137+
# enforce same types
138+
for report in reports[1:]:
139+
report._typeset = reports[0].typeset
140+
137141
# Obtain description sets
138142
descriptions = [report.get_description() for report in reports]
139143
for label, description in zip(labels, descriptions):

src/ydata_profiling/model/pandas/summary_pandas.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ydata_profiling.config import Settings
1313
from ydata_profiling.model.summarizer import BaseSummarizer
1414
from ydata_profiling.model.summary import describe_1d, get_series_descriptions
15+
from ydata_profiling.model.typeset import ProfilingTypeSet
1516
from ydata_profiling.utils.dataframe import sort_column_names
1617

1718

@@ -37,8 +38,13 @@ def pandas_describe_1d(
3738
# Make sure pd.NA is not in the series
3839
series = series.fillna(np.nan)
3940

40-
# get `infer_dtypes` (bool) from config
41-
if config.infer_dtypes:
41+
if (
42+
isinstance(typeset, ProfilingTypeSet)
43+
and typeset.type_schema
44+
and series.name in typeset.type_schema
45+
):
46+
vtype = typeset.type_schema[series.name]
47+
elif config.infer_dtypes:
4248
# Infer variable types
4349
vtype = typeset.infer_type(series)
4450
series = typeset.cast_to_inferred(series)
@@ -47,6 +53,7 @@ def pandas_describe_1d(
4753
# [new dtypes, changed using `astype` function are now considered]
4854
vtype = typeset.detect_type(series)
4955

56+
typeset.type_schema[series.name] = vtype
5057
return summarizer.summarize(config, series, dtype=vtype)
5158

5259

src/ydata_profiling/model/typeset.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,22 @@ def is_timedependent(series: pd.Series) -> bool:
241241

242242

243243
class ProfilingTypeSet(visions.VisionsTypeset):
    """Typeset built from the profiling settings, with optional per-column types.

    Args:
        config: report settings; stored on the instance and handed to
            ``typeset_types`` to obtain the types this typeset is built from.
        type_schema: optional mapping of column name to the type that column
            should be treated as. Values are normally type names, matched
            case-insensitively against the names of this typeset's types.
            ``None`` or an empty mapping means no column is pinned.

    Raises:
        ValueError: if a value in ``type_schema`` matches none of the types
            of this typeset.
    """

    def __init__(self, config: Settings, type_schema: dict = None):
        self.config = config

        types = typeset_types(config)

        with warnings.catch_warnings():
            # UserWarnings emitted while the parent typeset is initialised
            # are deliberately silenced.
            warnings.filterwarnings("ignore", category=UserWarning)
            super().__init__(types)

        # Resolved only after the parent initialiser has run, because name
        # resolution iterates ``self.types`` (presumably populated by the
        # parent class -- it is not assigned anywhere in this class).
        self.type_schema = self._init_type_schema(type_schema or {})

    def _init_type_schema(self, type_schema: dict) -> dict:
        """Return a copy of ``type_schema`` with every value resolved to a type."""
        return {
            column: self._get_type(type_name)
            for column, type_name in type_schema.items()
        }

    def _get_type(self, type_name: object) -> visions.VisionsBaseType:
        """Return the type of this typeset whose name matches ``type_name``.

        Args:
            type_name: a type name (compared case-insensitively), or an object
                exposing ``__name__`` such as an already-resolved type. The
                latter lets a resolved ``type_schema`` taken from one typeset
                be used to build another instead of failing on ``.lower()``.

        Raises:
            ValueError: if no type of this typeset carries that name.
        """
        if isinstance(type_name, str):
            name = type_name
        else:
            # Resolved schema entries are type objects, not strings; match
            # them by name, since an equal-looking type from another typeset
            # is not necessarily the same object.
            name = getattr(type_name, "__name__", None)

        if isinstance(name, str):
            wanted = name.lower()
            for candidate in self.types:
                if candidate.__name__.lower() == wanted:
                    return candidate
        raise ValueError(f"Type [{type_name}] not found.")

src/ydata_profiling/profile_report.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
typeset: Optional[VisionsTypeset] = None,
6969
summarizer: Optional[BaseSummarizer] = None,
7070
config: Optional[Settings] = None,
71+
type_schema: Optional[dict] = None,
7172
**kwargs,
7273
):
7374
"""Generate a ProfileReport based on a pandas or spark.sql DataFrame
@@ -89,6 +90,7 @@ def __init__(
8990
sample: optional dict(name="Sample title", caption="Caption", data=pd.DataFrame())
9091
typeset: optional user typeset to use for type inference
9192
summarizer: optional user summarizer to generate custom summary output
93+
type_schema: optional dict containing pairs of `column name`: `type`
9294
**kwargs: other arguments, for valid arguments, check the default configuration file.
9395
"""
9496
self.__validate_inputs(df, minimal, tsmode, config_file, lazy)
@@ -139,6 +141,7 @@ def __init__(
139141
self.config = report_config
140142
self._df_hash = None
141143
self._sample = sample
144+
self._type_schema = type_schema
142145
self._typeset = typeset
143146
self._summarizer = summarizer
144147

@@ -230,7 +233,7 @@ def invalidate_cache(self, subset: Optional[str] = None) -> None:
230233
@property
def typeset(self) -> Optional[VisionsTypeset]:
    """Return the typeset of this report, creating the default one on demand.

    When no typeset has been set yet, a ``ProfilingTypeSet`` is built from the
    report's config and its ``_type_schema`` and cached on ``_typeset``, so
    later accesses return the same object.
    """
    if self._typeset is not None:
        return self._typeset

    self._typeset = ProfilingTypeSet(self.config, self._type_schema)
    return self._typeset
235238

236239
@property

tests/unit/test_typeset_default.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
import numpy as np
4+
import pandas as pd
35
import pytest
46
from visions.test.series import get_series
57
from visions.test.utils import (
@@ -14,6 +16,7 @@
1416
from tests.unit.test_utils import patch_arg
1517
from ydata_profiling.config import Settings
1618
from ydata_profiling.model.typeset import ProfilingTypeSet
19+
from ydata_profiling.profile_report import ProfileReport
1720

1821
base_path = os.path.abspath(os.path.dirname(__file__))
1922

@@ -161,7 +164,7 @@
161164
)
162165
)
163166
def test_contains(name, series, contains_type, member):
164-
"""Test the generated combinations for "series in type"
167+
"""Test the generated combinations for "series in type".
165168
166169
Args:
167170
series: the series to test
@@ -349,3 +352,35 @@ def test_conversion(name, source_type, relation_type, series, member):
349352
"""
350353
result, message = convert(name, source_type, relation_type, series, member)
351354
assert result, message
355+
356+
357+
@pytest.fixture
def dataframe(size: int = 1000) -> pd.DataFrame:
    """Return a synthetic frame with one column per kind of data under test.

    Args:
        size: number of rows to generate.

    Returns:
        A frame with the columns ``boolean`` (random True/False), ``numeric``
        (random floats), ``categorical`` (random integers drawn from 0-4) and
        ``timeseries`` (the integers 0..size-1 in order).
    """
    # A locally seeded generator keeps the data identical on every run, so a
    # failure can be reproduced; the global NumPy random state is not touched.
    rng = np.random.default_rng(seed=0)
    return pd.DataFrame(
        {
            "boolean": rng.choice([True, False], size=size),
            "numeric": rng.random(size),
            "categorical": rng.choice(np.arange(5), size=size),
            "timeseries": np.arange(size),
        }
    )
367+
368+
369+
def convertion_map() -> list:
    """Build the ``(column, type_schema)`` cases used to parametrize the schema test.

    Returns:
        One tuple per (column, target type) combination: the column name and a
        single-entry schema mapping that column to the target type name.
    """
    targets_by_column = {
        "boolean": ["Categorical", "Unsupported"],
        "numeric": ["Categorical", "Boolean", "Unsupported"],
        "categorical": ["Numeric", "Boolean", "TimeSeries", "Unsupported"],
        "timeseries": ["Numeric", "Boolean", "Categorical", "Unsupported"],
    }

    cases = []
    for column, targets in targets_by_column.items():
        for target in targets:
            cases.append((column, {column: target}))
    return cases
377+
378+
379+
@pytest.mark.parametrize("column,type_schema", convertion_map())
def test_type_schema(dataframe: pd.DataFrame, column: str, type_schema: dict):
    """Check that a user-supplied type for a column is kept after profiling.

    Args:
        dataframe: the generated test frame.
        column: the single column to profile.
        type_schema: mapping of that column to the requested type name.
    """
    report = ProfileReport(dataframe[[column]], tsmode=True, type_schema=type_schema)
    # Run the description first so the assertions see the post-profiling schema.
    report.get_description()

    typeset = report.typeset
    assert isinstance(typeset, ProfilingTypeSet)

    requested = typeset._get_type(type_schema[column])
    assert typeset.type_schema[column] == requested

0 commit comments

Comments
 (0)