Skip to content

Commit bc1016e

Browse files
committed
PLAT-1008: Make datetime detection more strict by optionally supporting a must_match_all parameter
GitOrigin-RevId: a8a43e3b3b27d0375d7afc682661711ba630ef51
1 parent c8cfc19 commit bc1016e

File tree

4 files changed

+160
-37
lines changed

4 files changed

+160
-37
lines changed

src/gretel_synthetics/actgan/actgan_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def fit(self, data: Union[pd.DataFrame, str]) -> None:
7070
if self._auto_transform_datetimes:
7171
if self._verbose:
7272
logger.info("Attempting datetime auto-detection...")
73-
detector.fit_datetime(data, with_suffix=True)
73+
detector.fit_datetime(data, with_suffix=True, must_match_all=True)
7474

7575
detector.fit_empty_columns(data)
7676
if self._verbose:

src/gretel_synthetics/detectors/dates.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,23 @@ def _maybe_match(date, format) -> Tuple[Optional[datetime], Optional[str]]:
309309
return None, None
310310

311311

312+
def _check_series(series: pd.Series, format: str) -> bool:
313+
# Remove non-standard formatting directives which are relevant for formatting
314+
# only, not for parsing. The first one, `!`, is introduced by us (see
315+
# ``_strptime_extra``), the second one, `%-`, is a directive not recognized
316+
# by pandas and stripped by RDT as well (see
317+
# https://github.com/sdv-dev/RDT/pull/458/files#r835690711 ).
318+
pd_format = format.replace("!", "").replace("%-", "%")
319+
try:
320+
pd.to_datetime(series, format=pd_format)
321+
return True
322+
except:
323+
# Conservatively ignore any error, and assume that the format
324+
# didn't work.
325+
# This is to prevent errors in the SDV code downstream.
326+
return False
327+
328+
312329
def _parse_date_multiple(
313330
input_date: str,
314331
date_str_fmts: Union[List[str], Set[str]] = _date_str_fmt_permutations,
@@ -334,7 +351,46 @@ def _maybe_d_str_to_fmt_multiple(input_date: str, with_suffix: bool) -> Iterator
334351
pass
335352

336353

337-
def _infer_from_series(series: Iterable[str], with_suffix: bool) -> Optional[str]:
354+
def _infer_from_series_match_all(series: pd.Series, with_suffix: bool) -> Optional[str]:
355+
if series.empty:
356+
return None
357+
358+
# We store the candidate formats as a list instead of a set to ensure a deterministic
359+
# result (the order of ``_maybe_d_str_to_fmt_multiple`` is deterministic as well).
360+
# This matches the behavior of ``_infer_from_series``, which - due to the above
361+
# property as well as ``Counter``s stable iteration based on insertion order -
362+
# is deterministic as well.
363+
candidate_fmts = list(_maybe_d_str_to_fmt_multiple(series[0], with_suffix))
364+
i = 1
365+
# Empirically, ``pd.to_datetime`` is about 8x faster than checking individual values.
366+
# Conservatively, we fall back to calling ``pd.to_datetime`` on the entire remaining
367+
# series when we have 4 or less candidate formats less.
368+
# In most cases, the number of candidate formats will be lower than both 4 and 8
369+
# after the first invocation anyway.
370+
while len(candidate_fmts) > 4 and i < len(series):
371+
value = series[i]
372+
candidate_fmts = [
373+
fmt for fmt in candidate_fmts if _maybe_match(value, fmt) != (None, None)
374+
]
375+
i += 1
376+
377+
if i < len(series):
378+
# If we haven't exhausted the whole series yet, do a ``pd.to_datetime``
379+
# call for the remaining values to weed out incorrect formats.
380+
remaining_series = series[i:]
381+
candidate_fmts = [
382+
fmt for fmt in candidate_fmts if _check_series(remaining_series, fmt)
383+
]
384+
385+
return candidate_fmts[0] if candidate_fmts else None
386+
387+
388+
def _infer_from_series(
389+
series: pd.Series, with_suffix: bool, must_match_all: bool = False
390+
) -> Optional[str]:
391+
if must_match_all:
392+
return _infer_from_series_match_all(series, with_suffix)
393+
338394
counter = Counter()
339395
for value in series:
340396
for fmt in _maybe_d_str_to_fmt_multiple(value, with_suffix):
@@ -347,7 +403,10 @@ def _infer_from_series(series: Iterable[str], with_suffix: bool) -> Optional[str
347403

348404

349405
def detect_datetimes(
350-
df: pd.DataFrame, sample_size: Optional[int] = None, with_suffix: bool = False
406+
df: pd.DataFrame,
407+
sample_size: Optional[int] = None,
408+
with_suffix: bool = False,
409+
must_match_all: bool = False,
351410
) -> DateTimeColumns:
352411
if sample_size is None:
353412
sample_size = SAMPLE_SIZE
@@ -356,9 +415,14 @@ def detect_datetimes(
356415
col for col, col_type in df.dtypes.iteritems() if col_type == "object"
357416
]
358417
for object_col in object_cols:
359-
curr_series: pd.Series = df[object_col].dropna(axis=0).reset_index(drop=True)
360-
sampled_series_str = (curr_series.sample(sample_size, replace=True)).astype(str)
361-
inferred_format = _infer_from_series(sampled_series_str, with_suffix)
418+
test_series: pd.Series = df[object_col].dropna(axis=0).reset_index(drop=True)
419+
# Only sample when we don't require the format to match all entries
420+
if not must_match_all and len(test_series) > sample_size:
421+
test_series = test_series.sample(sample_size)
422+
test_series_str = test_series.astype(str)
423+
inferred_format = _infer_from_series(
424+
test_series_str, with_suffix, must_match_all
425+
)
362426
if inferred_format is not None:
363427
inferred_format = inferred_format.replace("!", "")
364428
column_data.columns[object_col] = DateTimeColumn(

src/gretel_synthetics/detectors/sdv.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,13 @@ def fit_datetime(
164164
data: pd.DataFrame,
165165
sample_size: Optional[int] = None,
166166
with_suffix: bool = False,
167+
must_match_all: bool = False,
167168
) -> None:
168169
detections = detect_datetimes(
169-
data, sample_size=sample_size, with_suffix=with_suffix
170+
data,
171+
sample_size=sample_size,
172+
with_suffix=with_suffix,
173+
must_match_all=must_match_all,
170174
)
171175
for _, column_info in detections.columns.items():
172176
type_, transformer = datetime_column_to_sdv(column_info)

tests/detectors/test_detectors_dates.py

Lines changed: 85 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from datetime import datetime, timedelta, timezone
44

5+
import numpy as np
56
import pandas as pd
67
import pytest
78

@@ -67,31 +68,70 @@ def test_date_str_tokenizer(input_str, expected_mask):
6768
assert _tokenize_date_str(input_str).masked_str == expected_mask
6869

6970

70-
def test_infer_from_series():
71-
dates = ["12/20/2020", "10/17/2020", "08/10/2020", "01/22/2020", "09/01/2020"]
72-
assert _infer_from_series(dates, False) == "%m/%d/%Y"
71+
@pytest.mark.parametrize("must_match_all", [False, True])
72+
def test_infer_from_series(must_match_all):
73+
dates = pd.Series(
74+
["12/20/2020", "10/17/2020", "08/10/2020", "01/22/2020", "09/01/2020"]
75+
)
76+
assert _infer_from_series(dates, False, must_match_all=must_match_all) == "%m/%d/%Y"
7377

7478

75-
def test_infer_from_bad_date():
76-
dates = ["#NAME?", "1000#", "Jim", "3", "$moola"]
77-
assert _infer_from_series(dates, False) is None
79+
@pytest.mark.parametrize("must_match_all", [False, True])
80+
def test_infer_from_bad_date(must_match_all):
81+
dates = pd.Series(["#NAME?", "1000#", "Jim", "3", "$moola"])
82+
assert _infer_from_series(dates, False, must_match_all=must_match_all) is None
7883

7984

8085
def test_infer_from_some_bad_date():
81-
dates = ["#NAME?", "1000#", "Jim", "3", "10/17/2020"]
82-
assert _infer_from_series(dates, False) == "%m/%d/%Y"
86+
dates = pd.Series(["#NAME?", "1000#", "Jim", "3", "10/17/2020"])
87+
assert _infer_from_series(dates, False, must_match_all=False) == "%m/%d/%Y"
88+
89+
90+
def test_infer_from_some_bad_date_with_match_all():
91+
dates = pd.Series(["#NAME?", "1000#", "Jim", "3", "10/17/2020"])
92+
assert _infer_from_series(dates, False, must_match_all=True) is None
93+
94+
95+
@pytest.mark.parametrize("must_match_all", [False, True])
96+
def test_infer_from_12_hour(must_match_all):
97+
dates = pd.Series(["8:15 AM", "9:20 PM", "1:55 PM"])
98+
assert _infer_from_series(dates, False, must_match_all=must_match_all) == "%I:%M %p"
99+
100+
101+
@pytest.mark.parametrize("with_suffix", [True, False])
102+
@pytest.mark.parametrize("must_match_all", [False, True])
103+
def test_detect_datetimes(with_suffix, must_match_all, test_df):
104+
# Based on the values in the DF, we assert the `with_suffix` flag
105+
# should not change any of the results
106+
check = detect_datetimes(
107+
test_df, with_suffix=with_suffix, must_match_all=must_match_all
108+
)
109+
assert set(check.column_names) == {"dates", "iso"}
110+
assert check.get_column_info("random") is None
83111

112+
dates = check.get_column_info("dates")
113+
assert dates.name == "dates"
114+
assert dates.inferred_format == "%m/%d/%Y"
84115

85-
def test_infer_from_12_hour():
86-
dates = ["8:15 AM", "9:20 PM", "1:55 PM"]
87-
assert _infer_from_series(dates, False) == "%I:%M %p"
116+
iso = check.get_column_info("iso")
117+
assert iso.name == "iso"
118+
assert iso.inferred_format == "%Y-%m-%dT%X.%f"
88119

89120

90121
@pytest.mark.parametrize("with_suffix", [True, False])
91-
def test_detect_datetimes(with_suffix, test_df):
122+
@pytest.mark.parametrize("must_match_all", [False, True])
123+
def test_detect_datetimes_with_nans(with_suffix, must_match_all, test_df):
124+
# Create a copy to prevent modification to the session-scoped fixture
125+
# object.
126+
test_df = test_df.copy()
127+
# Blank out first row
128+
test_df.iloc[0, :] = np.nan
129+
92130
# Based on the values in the DF, we assert the `with_suffix` flag
93131
# should not change any of the results
94-
check = detect_datetimes(test_df, with_suffix=with_suffix)
132+
check = detect_datetimes(
133+
test_df, with_suffix=with_suffix, must_match_all=must_match_all
134+
)
95135
assert set(check.column_names) == {"dates", "iso"}
96136
assert check.get_column_info("random") is None
97137

@@ -104,27 +144,41 @@ def test_detect_datetimes(with_suffix, test_df):
104144
assert iso.inferred_format == "%Y-%m-%dT%X.%f"
105145

106146

107-
def test_infer_with_suffix():
108-
dates = [
109-
"2020-12-20T00:00:00Z",
110-
"2020-10-17T00:00:00Z",
111-
"2020-08-10T00:00:00Z",
112-
"2020-01-22T00:00:00Z",
113-
"2020-09-01T00:00:00Z",
114-
]
115-
assert _infer_from_series(dates, True) == "%Y-%m-%dT%XZ"
147+
@pytest.mark.parametrize("must_match_all", [False, True])
148+
def test_infer_with_suffix(must_match_all):
149+
dates = pd.Series(
150+
[
151+
"2020-12-20T00:00:00Z",
152+
"2020-10-17T00:00:00Z",
153+
"2020-08-10T00:00:00Z",
154+
"2020-01-22T00:00:00Z",
155+
"2020-09-01T00:00:00Z",
156+
]
157+
)
158+
assert (
159+
_infer_from_series(dates, True, must_match_all=must_match_all) == "%Y-%m-%dT%XZ"
160+
)
116161

117-
dates_2 = [d.replace("Z", "+00:00") for d in dates.copy()]
118-
assert _infer_from_series(dates_2, True) == "%Y-%m-%dT%X+00:00"
162+
dates_2 = pd.Series([d.replace("Z", "+00:00") for d in dates])
163+
assert (
164+
_infer_from_series(dates_2, True, must_match_all=must_match_all)
165+
== "%Y-%m-%dT%X+00:00"
166+
)
119167

120-
dates_3 = [d.replace("Z", "-00:00") for d in dates.copy()]
121-
assert _infer_from_series(dates_3, True) == "%Y-%m-%dT%X-00:00"
168+
dates_3 = pd.Series([d.replace("Z", "-00:00") for d in dates])
169+
assert (
170+
_infer_from_series(dates_3, True, must_match_all=must_match_all)
171+
== "%Y-%m-%dT%X-00:00"
172+
)
122173

123174

124-
def test_detect_datetimes_with_suffix(test_df):
175+
@pytest.mark.parametrize("must_match_all", [False, True])
176+
def test_detect_datetimes_with_suffix(must_match_all, test_df):
177+
# Prevent modification of the session-scoped fixture object
178+
test_df = test_df.copy()
125179
# Add a TZ suffix of "Z" to the iso strings
126180
test_df["iso"] = test_df["iso"].astype("string").apply(lambda val: val + "Z")
127-
check = detect_datetimes(test_df, with_suffix=True)
181+
check = detect_datetimes(test_df, with_suffix=True, must_match_all=must_match_all)
128182
assert set(check.column_names) == {"dates", "iso"}
129183

130184
iso = check.get_column_info("iso")
@@ -134,7 +188,8 @@ def test_detect_datetimes_with_suffix(test_df):
134188
assert iso.inferred_format == "%Y-%m-%dT%X.%fZ"
135189

136190

137-
def test_detect_datetimes_custom_formats():
191+
@pytest.mark.parametrize("must_match_all", [False, True])
192+
def test_detect_datetimes_custom_formats(must_match_all):
138193
df = pd.DataFrame(
139194
{
140195
"str": ["a", "b", "c"],
@@ -151,7 +206,7 @@ def test_detect_datetimes_custom_formats():
151206
}
152207
)
153208

154-
check = detect_datetimes(df)
209+
check = detect_datetimes(df, must_match_all=must_match_all)
155210

156211
assert set(check.column_names) == {
157212
"dateandtime",

0 commit comments

Comments
 (0)