Skip to content

Commit 4a1c501

Browse files
authored
Use cached_property for Datetime and Timedelta column properties (#18601)
xref #5695 Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Matthew Murray (https://github.com/Matt711) URL: #18601
1 parent d4e49dd commit 4a1c501

File tree

4 files changed

+98
-43
lines changed

4 files changed

+98
-43
lines changed

python/cudf/cudf/core/column/datetime.py

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,38 @@ def __init__(
249249
children=children,
250250
)
251251

252+
def _clear_cache(self) -> None:
253+
super()._clear_cache()
254+
attrs = (
255+
"days_in_month",
256+
"is_year_start",
257+
"is_leap_year",
258+
"is_year_end",
259+
"is_quarter_start",
260+
"is_quarter_end",
261+
"is_month_start",
262+
"is_month_end",
263+
"day_of_year",
264+
"weekday",
265+
"nanosecond",
266+
"microsecond",
267+
"millisecond",
268+
"second",
269+
"minute",
270+
"hour",
271+
"day",
272+
"month",
273+
"year",
274+
"quarter",
275+
"time_unit",
276+
)
277+
for attr in attrs:
278+
try:
279+
delattr(self, attr)
280+
except AttributeError:
281+
# attr was not called yet, so ignore.
282+
pass
283+
252284
@staticmethod
253285
def _validate_dtype_instance(dtype: np.dtype) -> np.dtype:
254286
if not (isinstance(dtype, np.dtype) and dtype.kind == "M"):
@@ -287,86 +319,86 @@ def _validate_fillna_value(
287319
def time_unit(self) -> str:
288320
return np.datetime_data(self.dtype)[0]
289321

290-
@property
322+
@functools.cached_property
291323
@acquire_spill_lock()
292324
def quarter(self) -> ColumnBase:
293325
return type(self).from_pylibcudf(
294326
plc.datetime.extract_quarter(self.to_pylibcudf(mode="read"))
295327
)
296328

297-
@property
329+
@functools.cached_property
298330
def year(self) -> ColumnBase:
299331
return self._get_dt_field(plc.datetime.DatetimeComponent.YEAR)
300332

301-
@property
333+
@functools.cached_property
302334
def month(self) -> ColumnBase:
303335
return self._get_dt_field(plc.datetime.DatetimeComponent.MONTH)
304336

305-
@property
337+
@functools.cached_property
306338
def day(self) -> ColumnBase:
307339
return self._get_dt_field(plc.datetime.DatetimeComponent.DAY)
308340

309-
@property
341+
@functools.cached_property
310342
def hour(self) -> ColumnBase:
311343
return self._get_dt_field(plc.datetime.DatetimeComponent.HOUR)
312344

313-
@property
345+
@functools.cached_property
314346
def minute(self) -> ColumnBase:
315347
return self._get_dt_field(plc.datetime.DatetimeComponent.MINUTE)
316348

317-
@property
349+
@functools.cached_property
318350
def second(self) -> ColumnBase:
319351
return self._get_dt_field(plc.datetime.DatetimeComponent.SECOND)
320352

321-
@property
353+
@functools.cached_property
322354
def millisecond(self) -> ColumnBase:
323355
return self._get_dt_field(plc.datetime.DatetimeComponent.MILLISECOND)
324356

325-
@property
357+
@functools.cached_property
326358
def microsecond(self) -> ColumnBase:
327359
return self._get_dt_field(plc.datetime.DatetimeComponent.MICROSECOND)
328360

329-
@property
361+
@functools.cached_property
330362
def nanosecond(self) -> ColumnBase:
331363
return self._get_dt_field(plc.datetime.DatetimeComponent.NANOSECOND)
332364

333-
@property
365+
@functools.cached_property
334366
def weekday(self) -> ColumnBase:
335367
# pandas counts Monday-Sunday as 0-6
336368
# while libcudf counts Monday-Sunday as 1-7
337369
result = self._get_dt_field(plc.datetime.DatetimeComponent.WEEKDAY)
338370
return result - result.dtype.type(1)
339371

340-
@property
372+
@functools.cached_property
341373
@acquire_spill_lock()
342374
def day_of_year(self) -> ColumnBase:
343375
return type(self).from_pylibcudf(
344376
plc.datetime.day_of_year(self.to_pylibcudf(mode="read"))
345377
)
346378

347-
@property
379+
@functools.cached_property
348380
def is_month_start(self) -> ColumnBase:
349381
return (self.day == 1).fillna(False)
350382

351-
@property
383+
@functools.cached_property
352384
def is_month_end(self) -> ColumnBase:
353385
with acquire_spill_lock():
354386
last_day_col = type(self).from_pylibcudf(
355387
plc.datetime.last_day_of_month(self.to_pylibcudf(mode="read"))
356388
)
357389
return (self.day == last_day_col.day).fillna(False) # type: ignore[attr-defined]
358390

359-
@property
391+
@functools.cached_property
360392
def is_quarter_end(self) -> ColumnBase:
361393
last_month = self.month.isin([3, 6, 9, 12])
362394
return (self.is_month_end & last_month).fillna(False)
363395

364-
@property
396+
@functools.cached_property
365397
def is_quarter_start(self) -> ColumnBase:
366398
first_month = self.month.isin([1, 4, 7, 10])
367399
return (self.is_month_start & first_month).fillna(False)
368400

369-
@property
401+
@functools.cached_property
370402
def is_year_end(self) -> ColumnBase:
371403
day_of_year = self.day_of_year
372404
leap_dates = self.is_leap_year
@@ -375,18 +407,18 @@ def is_year_end(self) -> ColumnBase:
375407
non_leap = day_of_year == 365
376408
return leap.copy_if_else(non_leap, leap_dates).fillna(False)
377409

378-
@property
410+
@functools.cached_property
379411
@acquire_spill_lock()
380412
def is_leap_year(self) -> ColumnBase:
381413
return type(self).from_pylibcudf(
382414
plc.datetime.is_leap_year(self.to_pylibcudf(mode="read"))
383415
)
384416

385-
@property
417+
@functools.cached_property
386418
def is_year_start(self) -> ColumnBase:
387419
return (self.day_of_year == 1).fillna(False)
388420

389-
@property
421+
@functools.cached_property
390422
@acquire_spill_lock()
391423
def days_in_month(self) -> ColumnBase:
392424
return type(self).from_pylibcudf(
@@ -417,7 +449,7 @@ def values(self):
417449
Return a CuPy representation of the DateTimeColumn.
418450
"""
419451
raise NotImplementedError(
420-
"DateTime Arrays is not yet implemented in cudf"
452+
"DateTime Arrays is not yet implemented in cupy"
421453
)
422454

423455
def element_indexing(self, index: int):
@@ -922,12 +954,12 @@ def can_cast_safely(self, to_dtype: DtypeObj) -> bool:
922954
else:
923955
return False
924956

925-
def _with_type_metadata(self, dtype):
957+
def _with_type_metadata(self, dtype) -> DatetimeColumn:
926958
if isinstance(dtype, pd.DatetimeTZDtype):
927959
return DatetimeTZColumn(
928-
data=self.base_data,
960+
data=self.base_data, # type: ignore[arg-type]
929961
dtype=dtype,
930-
mask=self.base_mask,
962+
mask=self.base_mask, # type: ignore[arg-type]
931963
size=self.size,
932964
offset=self.offset,
933965
null_count=self.null_count,
@@ -1003,7 +1035,7 @@ def tz_localize(
10031035
tz: str | None,
10041036
ambiguous: Literal["NaT"] = "NaT",
10051037
nonexistent: Literal["NaT"] = "NaT",
1006-
):
1038+
) -> DatetimeColumn:
10071039
if tz is None:
10081040
return self.copy()
10091041
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
@@ -1087,6 +1119,13 @@ def __init__(
10871119
children=children,
10881120
)
10891121

1122+
def _clear_cache(self) -> None:
1123+
super()._clear_cache()
1124+
try:
1125+
del self._local_time
1126+
except AttributeError:
1127+
pass
1128+
10901129
@staticmethod
10911130
def _validate_dtype_instance(
10921131
dtype: pd.DatetimeTZDtype,
@@ -1118,25 +1157,24 @@ def time_unit(self) -> str:
11181157
return self.dtype.unit
11191158

11201159
@property
1121-
def _utc_time(self):
1160+
def _utc_time(self) -> DatetimeColumn:
11221161
"""Return UTC time as naive timestamps."""
11231162
return DatetimeColumn(
1124-
data=self.base_data,
1163+
data=self.base_data, # type: ignore[arg-type]
11251164
dtype=_get_base_dtype(self.dtype),
1126-
mask=self.base_mask,
1165+
mask=self.base_mask, # type: ignore[arg-type]
11271166
size=self.size,
11281167
offset=self.offset,
11291168
null_count=self.null_count,
11301169
)
11311170

1132-
@property
1133-
def _local_time(self):
1171+
@functools.cached_property
1172+
def _local_time(self) -> DatetimeColumn:
11341173
"""Return the local time as naive timestamps."""
11351174
transition_times, offsets = get_tz_data(str(self.dtype.tz))
11361175
base_dtype = _get_base_dtype(self.dtype)
1137-
transition_times = transition_times.astype(base_dtype)
11381176
indices = (
1139-
transition_times.searchsorted(
1177+
transition_times.astype(base_dtype).searchsorted(
11401178
self.astype(base_dtype), side="right"
11411179
)
11421180
- 1
@@ -1173,7 +1211,7 @@ def _get_dt_field(
11731211
)
11741212
)
11751213

1176-
def __repr__(self):
1214+
def __repr__(self) -> str:
11771215
# Arrow prints the UTC timestamps, but we want to print the
11781216
# local timestamps:
11791217
arr = self._local_time.to_arrow().cast(
@@ -1183,7 +1221,9 @@ def __repr__(self):
11831221
f"{object.__repr__(self)}\n{arr.to_string()}\ndtype: {self.dtype}"
11841222
)
11851223

1186-
def tz_localize(self, tz: str | None, ambiguous="NaT", nonexistent="NaT"):
1224+
def tz_localize(
1225+
self, tz: str | None, ambiguous="NaT", nonexistent="NaT"
1226+
) -> DatetimeColumn:
11871227
if tz is None:
11881228
return self._local_time
11891229
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
@@ -1194,14 +1234,14 @@ def tz_localize(self, tz: str | None, ambiguous="NaT", nonexistent="NaT"):
11941234
"Use `tz_convert` to convert between time zones."
11951235
)
11961236

1197-
def tz_convert(self, tz: str | None):
1237+
def tz_convert(self, tz: str | None) -> DatetimeColumn:
11981238
if tz is None:
11991239
return self._utc_time
12001240
elif tz == str(self.dtype.tz):
12011241
return self.copy()
12021242
utc_time = self._utc_time
12031243
return type(self)(
1204-
data=utc_time.base_data,
1244+
data=utc_time.base_data, # type: ignore[arg-type]
12051245
dtype=pd.DatetimeTZDtype(self.time_unit, tz),
12061246
mask=utc_time.base_mask,
12071247
size=utc_time.size,

python/cudf/cudf/core/column/timedelta.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,21 @@ def __init__(
139139
children=children,
140140
)
141141

142+
def _clear_cache(self) -> None:
143+
super()._clear_cache()
144+
attrs = (
145+
"days",
146+
"seconds",
147+
"microseconds",
148+
"nanoseconds",
149+
"time_unit",
150+
)
151+
for attr in attrs:
152+
try:
153+
delattr(self, attr)
154+
except AttributeError:
155+
pass
156+
142157
def __contains__(self, item: DatetimeLikeScalar) -> bool:
143158
try:
144159
item = np.timedelta64(item, self.time_unit)
@@ -170,7 +185,7 @@ def values(self):
170185
Return a CuPy representation of the TimeDeltaColumn.
171186
"""
172187
raise NotImplementedError(
173-
"TimeDelta Arrays is not yet implemented in cudf"
188+
"TimeDelta Arrays is not yet implemented in cupy"
174189
)
175190

176191
def element_indexing(self, index: int):
@@ -610,7 +625,7 @@ def components(self) -> dict[str, ColumnBase]:
610625
data[result_key] = res_col
611626
return data
612627

613-
@property
628+
@functools.cached_property
614629
def days(self) -> cudf.core.column.NumericalColumn:
615630
"""
616631
Number of days for each element.
@@ -621,7 +636,7 @@ def days(self) -> cudf.core.column.NumericalColumn:
621636
"""
622637
return self // get_np_td_unit_conversion("D", self.dtype)
623638

624-
@property
639+
@functools.cached_property
625640
def seconds(self) -> cudf.core.column.NumericalColumn:
626641
"""
627642
Number of seconds (>= 0 and less than 1 day).
@@ -639,7 +654,7 @@ def seconds(self) -> cudf.core.column.NumericalColumn:
639654
self % get_np_td_unit_conversion("D", self.dtype)
640655
) // get_np_td_unit_conversion("s", None)
641656

642-
@property
657+
@functools.cached_property
643658
def microseconds(self) -> cudf.core.column.NumericalColumn:
644659
"""
645660
Number of microseconds (>= 0 and less than 1 second).
@@ -657,7 +672,7 @@ def microseconds(self) -> cudf.core.column.NumericalColumn:
657672
self % get_np_td_unit_conversion("s", self.dtype)
658673
) // get_np_td_unit_conversion("us", None)
659674

660-
@property
675+
@functools.cached_property
661676
def nanoseconds(self) -> cudf.core.column.NumericalColumn:
662677
"""
663678
Return the number of nanoseconds (n), where 0 <= n < 1 microsecond.

python/cudf/cudf/tests/test_datetime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1858,7 +1858,7 @@ def test_error_values():
18581858
s = cudf.Series([1, 2, 3], dtype="datetime64[ns]")
18591859
with pytest.raises(
18601860
NotImplementedError,
1861-
match="DateTime Arrays is not yet implemented in cudf",
1861+
match="DateTime Arrays is not yet implemented in cupy",
18621862
):
18631863
s.values
18641864

python/cudf/cudf/tests/test_timedelta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1213,7 +1213,7 @@ def test_error_values():
12131213
s = cudf.Series([1, 2, 3], dtype="timedelta64[ns]")
12141214
with pytest.raises(
12151215
NotImplementedError,
1216-
match="TimeDelta Arrays is not yet implemented in cudf",
1216+
match="TimeDelta Arrays is not yet implemented in cupy",
12171217
):
12181218
s.values
12191219

0 commit comments

Comments
 (0)