Skip to content

Commit d8c3b1a

Browse files
Add chunk-friendly code path to encode_cf_datetime and encode_cf_timedelta (#8575)
* Add proof of concept dask-friendly datetime encoding * Add dask support for timedelta encoding and more tests * Minor error message edits; add what's new entry * Add return type for new tests * Fix typo in what's new * Add what's new entry for update following #8542 * Add full type hints to encoding functions * Combine datetime64 and timedelta64 zarr tests; add cftime zarr test * Minor edits to what's new * Address initial review comments * Initial work toward addressing typing comments * Restore covariant=True in T_DuckArray; add type: ignores * Tweak netCDF3 error message * Move what's new entry * Remove extraneous text from merge in what's new * Remove unused type: ignore comment * Remove word from netCDF3 error message
1 parent e22b475 commit d8c3b1a

File tree

6 files changed

+438
-14
lines changed

6 files changed

+438
-14
lines changed

doc/whats-new.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@ Bug fixes
4444
By `Tom Nicholas <https://github.com/TomNicholas>`_.
4545
- Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant classes. (:issue:`8666`, :pull:`8668`)
4646
By `Tom Nicholas <https://github.com/TomNicholas>`_.
47+
- Preserve chunks when writing time-like variables to zarr by enabling lazy CF
48+
encoding of time-like variables (:issue:`7132`, :issue:`8230`, :issue:`8432`,
49+
:pull:`8575`). By `Spencer Clark <https://github.com/spencerkclark>`_ and
50+
`Mattia Almansi <https://github.com/malmans2>`_.
51+
- Preserve chunks when writing time-like variables to zarr by enabling their
52+
lazy encoding (:issue:`7132`, :issue:`8230`, :issue:`8432`, :pull:`8253`,
53+
:pull:`8575`; see also discussion in :pull:`8253`). By `Spencer Clark
54+
<https://github.com/spencerkclark>`_ and `Mattia Almansi
55+
<https://github.com/malmans2>`_.
56+
- Raise an informative error if dtype encoding of time-like variables would
57+
lead to integer overflow or unsafe conversion from floating point to integer
58+
values (:issue:`8542`, :pull:`8575`). By `Spencer Clark
59+
<https://github.com/spencerkclark>`_.
4760

4861
Documentation
4962
~~~~~~~~~~~~~

xarray/backends/netcdf3.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,21 @@
4242

4343
# encode all strings as UTF-8
4444
STRING_ENCODING = "utf-8"
45+
# Template for the ValueError raised by coerce_nc3_dtype when array values
# cannot be represented exactly in a netCDF3-compatible dtype.  Formatted
# with the source {dtype} and the attempted target {new_dtype}.  The message
# focuses on int64-encoded times since that is a common trigger when writing
# datetime64[ns]/timedelta64[ns] data to netCDF3.
COERCION_VALUE_ERROR = (
    "could not safely cast array from {dtype} to {new_dtype}. While it is not "
    "always the case, a common reason for this is that xarray has deemed it "
    "safest to encode np.datetime64[ns] or np.timedelta64[ns] values with "
    "int64 values representing units of 'nanoseconds'. This is either due to "
    "the fact that the times are known to require nanosecond precision for an "
    "accurate round trip, or that the times are unknown prior to writing due "
    "to being contained in a chunked array. Ways to work around this are "
    "either to use a backend that supports writing int64 values, or to "
    "manually specify the encoding['units'] and encoding['dtype'] (e.g. "
    "'seconds since 1970-01-01' and np.dtype('int32')) on the time "
    "variable(s) such that the times can be serialized in a netCDF3 file "
    "(note that depending on the situation, however, this latter option may "
    "result in an inaccurate round trip)."
)
4560

4661

4762
def coerce_nc3_dtype(arr):
@@ -66,7 +81,7 @@ def coerce_nc3_dtype(arr):
6681
cast_arr = arr.astype(new_dtype)
6782
if not (cast_arr == arr).all():
6883
raise ValueError(
69-
f"could not safely cast array from dtype {dtype} to {new_dtype}"
84+
COERCION_VALUE_ERROR.format(dtype=dtype, new_dtype=new_dtype)
7085
)
7186
arr = cast_arr
7287
return arr

xarray/coding/times.py

Lines changed: 173 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
)
2323
from xarray.core import indexing
2424
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
25+
from xarray.core.duck_array_ops import asarray
2526
from xarray.core.formatting import first_n_items, format_timestamp, last_item
27+
from xarray.core.parallelcompat import T_ChunkedArray, get_chunked_array_type
2628
from xarray.core.pdcompat import nanosecond_precision_timestamp
27-
from xarray.core.pycompat import is_duck_dask_array
29+
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
2830
from xarray.core.utils import emit_user_level_warning
2931
from xarray.core.variable import Variable
3032

@@ -34,7 +36,7 @@
3436
cftime = None
3537

3638
if TYPE_CHECKING:
37-
from xarray.core.types import CFCalendar
39+
from xarray.core.types import CFCalendar, T_DuckArray
3840

3941
T_Name = Union[Hashable, None]
4042

@@ -667,12 +669,48 @@ def _division(deltas, delta, floor):
667669
return num
668670

669671

672+
def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
673+
with warnings.catch_warnings():
674+
warnings.filterwarnings("ignore", message="overflow")
675+
cast_num = np.asarray(num, dtype=dtype)
676+
677+
if np.issubdtype(dtype, np.integer):
678+
if not (num == cast_num).all():
679+
if np.issubdtype(num.dtype, np.floating):
680+
raise ValueError(
681+
f"Not possible to cast all encoded times from "
682+
f"{num.dtype!r} to {dtype!r} without losing precision. "
683+
f"Consider modifying the units such that integer values "
684+
f"can be used, or removing the units and dtype encoding, "
685+
f"at which point xarray will make an appropriate choice."
686+
)
687+
else:
688+
raise OverflowError(
689+
f"Not possible to cast encoded times from "
690+
f"{num.dtype!r} to {dtype!r} without overflow. Consider "
691+
f"removing the dtype encoding, at which point xarray will "
692+
f"make an appropriate choice, or explicitly switching to "
693+
"a larger integer dtype."
694+
)
695+
else:
696+
if np.isinf(cast_num).any():
697+
raise OverflowError(
698+
f"Not possible to cast encoded times from {num.dtype!r} to "
699+
f"{dtype!r} without overflow. Consider removing the dtype "
700+
f"encoding, at which point xarray will make an appropriate "
701+
f"choice, or explicitly switching to a larger floating point "
702+
f"dtype."
703+
)
704+
705+
return cast_num
706+
707+
670708
def encode_cf_datetime(
    dates: T_DuckArray,  # type: ignore
    units: str | None = None,
    calendar: str | None = None,
    dtype: np.dtype | None = None,
) -> tuple[T_DuckArray, str, str]:
    """Given an array of datetime objects, returns the tuple `(num, units,
    calendar)` suitable for a CF compliant time variable.

    Chunked inputs are encoded lazily chunk-by-chunk; all other inputs are
    encoded eagerly.

    See Also
    --------
    cftime.date2num
    """
    dates = asarray(dates)
    # Dispatch on chunkedness so dask/cubed-backed arrays stay lazy.
    encode = (
        _lazily_encode_cf_datetime
        if is_chunked_array(dates)
        else _eagerly_encode_cf_datetime
    )
    return encode(dates, units, calendar, dtype)
728+
729+
730+
def _eagerly_encode_cf_datetime(
731+
dates: T_DuckArray, # type: ignore
732+
units: str | None = None,
733+
calendar: str | None = None,
734+
dtype: np.dtype | None = None,
735+
allow_units_modification: bool = True,
736+
) -> tuple[T_DuckArray, str, str]:
737+
dates = asarray(dates)
686738

687739
data_units = infer_datetime_units(dates)
688740

@@ -731,7 +783,7 @@ def encode_cf_datetime(
731783
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
732784
f"Set encoding['dtype'] to floating point dtype to silence this warning."
733785
)
734-
elif np.issubdtype(dtype, np.integer):
786+
elif np.issubdtype(dtype, np.integer) and allow_units_modification:
735787
new_units = f"{needed_units} since {format_timestamp(ref_date)}"
736788
emit_user_level_warning(
737789
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
@@ -752,12 +804,80 @@ def encode_cf_datetime(
752804
# we already covered for this in pandas-based flow
753805
num = cast_to_int_if_safe(num)
754806

755-
return (num, units, calendar)
807+
if dtype is not None:
808+
num = _cast_to_dtype_if_safe(num, dtype)
809+
810+
return num, units, calendar
811+
812+
813+
def _encode_cf_datetime_within_map_blocks(
    dates: T_DuckArray,  # type: ignore
    units: str,
    calendar: str,
    dtype: np.dtype,
) -> T_DuckArray:
    """Per-chunk datetime encoding kernel for ``chunkmanager.map_blocks``.

    Units and dtype are fixed up front by the caller, so unit modification
    is disallowed to keep every chunk encoded consistently.
    """
    encoded, _units, _calendar = _eagerly_encode_cf_datetime(
        dates, units, calendar, dtype, allow_units_modification=False
    )
    return encoded
823+
824+
825+
def _lazily_encode_cf_datetime(
    dates: T_ChunkedArray,
    units: str | None = None,
    calendar: str | None = None,
    dtype: np.dtype | None = None,
) -> tuple[T_ChunkedArray, str, str]:
    """Encode chunked datetime values without computing them.

    ``units`` and ``dtype`` must either both be given or both be left unset;
    in the latter case defaults are fixed up front so every chunk is encoded
    identically.
    """
    if calendar is None:
        # This will only trigger minor compute if dates is an object dtype array.
        calendar = infer_calendar_name(dates)

    if units is None and dtype is None:
        dtype = np.dtype("int64")
        if dates.dtype == "O":
            units = "microseconds since 1970-01-01"
        else:
            units = "nanoseconds since 1970-01-01"

    if units is None or dtype is None:
        raise ValueError(
            f"When encoding chunked arrays of datetime values, both the units "
            f"and dtype must be prescribed or both must be unprescribed. "
            f"Prescribing only one or the other is not currently supported. "
            f"Got a units encoding of {units} and a dtype encoding of {dtype}."
        )

    chunkmanager = get_chunked_array_type(dates)
    encoded = chunkmanager.map_blocks(
        _encode_cf_datetime_within_map_blocks,
        dates,
        units,
        calendar,
        dtype,
        dtype=dtype,
    )
    return encoded, units, calendar
756861

757862

758863
def encode_cf_timedelta(
    timedeltas: T_DuckArray,  # type: ignore
    units: str | None = None,
    dtype: np.dtype | None = None,
) -> tuple[T_DuckArray, str]:
    """Given an array of timedeltas, returns the tuple `(num, units)`
    suitable for a CF compliant time variable.

    Chunked inputs are encoded lazily chunk-by-chunk; all other inputs are
    encoded eagerly.
    """
    timedeltas = asarray(timedeltas)
    # Dispatch on chunkedness so dask/cubed-backed arrays stay lazy.
    encode = (
        _lazily_encode_cf_timedelta
        if is_chunked_array(timedeltas)
        else _eagerly_encode_cf_timedelta
    )
    return encode(timedeltas, units, dtype)
873+
874+
875+
def _eagerly_encode_cf_timedelta(
876+
timedeltas: T_DuckArray, # type: ignore
877+
units: str | None = None,
878+
dtype: np.dtype | None = None,
879+
allow_units_modification: bool = True,
880+
) -> tuple[T_DuckArray, str]:
761881
data_units = infer_timedelta_units(timedeltas)
762882

763883
if units is None:
@@ -784,7 +904,7 @@ def encode_cf_timedelta(
784904
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
785905
f"Set encoding['dtype'] to floating point dtype to silence this warning."
786906
)
787-
elif np.issubdtype(dtype, np.integer):
907+
elif np.issubdtype(dtype, np.integer) and allow_units_modification:
788908
emit_user_level_warning(
789909
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
790910
f"Serializing with units {needed_units!r} instead. "
@@ -797,7 +917,49 @@ def encode_cf_timedelta(
797917

798918
num = _division(time_deltas, time_delta, floor_division)
799919
num = num.values.reshape(timedeltas.shape)
800-
return (num, units)
920+
921+
if dtype is not None:
922+
num = _cast_to_dtype_if_safe(num, dtype)
923+
924+
return num, units
925+
926+
927+
def _encode_cf_timedelta_within_map_blocks(
    timedeltas: T_DuckArray,  # type:ignore
    units: str,
    dtype: np.dtype,
) -> T_DuckArray:
    """Per-chunk timedelta encoding kernel for ``chunkmanager.map_blocks``.

    Units and dtype are fixed up front by the caller, so unit modification
    is disallowed to keep every chunk encoded consistently.
    """
    encoded, _units = _eagerly_encode_cf_timedelta(
        timedeltas, units, dtype, allow_units_modification=False
    )
    return encoded
936+
937+
938+
def _lazily_encode_cf_timedelta(
    timedeltas: T_ChunkedArray, units: str | None = None, dtype: np.dtype | None = None
) -> tuple[T_ChunkedArray, str]:
    """Encode chunked timedelta values without computing them.

    ``units`` and ``dtype`` must either both be given or both be left unset;
    in the latter case defaults are fixed up front so every chunk is encoded
    identically.
    """
    if units is None and dtype is None:
        units = "nanoseconds"
        dtype = np.dtype("int64")

    if units is None or dtype is None:
        raise ValueError(
            f"When encoding chunked arrays of timedelta values, both the "
            f"units and dtype must be prescribed or both must be "
            f"unprescribed. Prescribing only one or the other is not "
            f"currently supported. Got a units encoding of {units} and a "
            f"dtype encoding of {dtype}."
        )

    chunkmanager = get_chunked_array_type(timedeltas)
    encoded = chunkmanager.map_blocks(
        _encode_cf_timedelta_within_map_blocks,
        timedeltas,
        units,
        dtype,
        dtype=dtype,
    )
    return encoded, units
801963

802964

803965
class CFDatetimeCoder(VariableCoder):

xarray/core/parallelcompat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from xarray.core.pycompat import is_chunked_array
2424

25-
T_ChunkedArray = TypeVar("T_ChunkedArray")
25+
T_ChunkedArray = TypeVar("T_ChunkedArray", bound=Any)
2626

2727
if TYPE_CHECKING:
2828
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
@@ -310,7 +310,7 @@ def rechunk(
310310
dask.array.Array.rechunk
311311
cubed.Array.rechunk
312312
"""
313-
return data.rechunk(chunks, **kwargs) # type: ignore[attr-defined]
313+
return data.rechunk(chunks, **kwargs)
314314

315315
@abstractmethod
316316
def compute(self, *data: T_ChunkedArray | Any, **kwargs) -> tuple[np.ndarray, ...]:

xarray/tests/test_backends.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from xarray.backends.pydap_ import PydapDataStore
5050
from xarray.backends.scipy_ import ScipyBackendEntrypoint
51+
from xarray.coding.cftime_offsets import cftime_range
5152
from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype
5253
from xarray.coding.variables import SerializationWarning
5354
from xarray.conventions import encode_dataset_coordinates
@@ -2929,6 +2930,28 @@ def test_attributes(self, obj) -> None:
29292930
with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."):
29302931
ds.to_zarr(store_target, **self.version_kwargs)
29312932

2933+
@requires_dask
2934+
@pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"])
2935+
def test_chunked_datetime64_or_timedelta64(self, dtype) -> None:
2936+
# Generalized from @malmans2's test in PR #8253
2937+
original = create_test_data().astype(dtype).chunk(1)
2938+
with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
2939+
for name, actual_var in actual.variables.items():
2940+
assert original[name].chunks == actual_var.chunks
2941+
assert original.chunks == actual.chunks
2942+
2943+
@requires_cftime
2944+
@requires_dask
2945+
def test_chunked_cftime_datetime(self) -> None:
2946+
# Based on @malmans2's test in PR #8253
2947+
times = cftime_range("2000", freq="D", periods=3)
2948+
original = xr.Dataset(data_vars={"chunked_times": (["time"], times)})
2949+
original = original.chunk({"time": 1})
2950+
with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
2951+
for name, actual_var in actual.variables.items():
2952+
assert original[name].chunks == actual_var.chunks
2953+
assert original.chunks == actual.chunks
2954+
29322955
def test_vectorized_indexing_negative_step(self) -> None:
29332956
if not has_dask:
29342957
pytest.xfail(

0 commit comments

Comments (0)