From 297cd2e0503b86990dc9f84f69dd215498a8a147 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 1 Jul 2025 16:14:49 -0600
Subject: [PATCH] Add xarray-specific encoding convention for pd.IntervalArray

Closes #2847
xref https://github.com/pydata/xarray/issues/8005#issuecomment-3011641252
---
 xarray/coding/variables.py       | 63 ++++++++++++++++++++++++++++++++
 xarray/conventions.py            |  5 +++
 xarray/tests/test_coding.py      | 35 ++++++++++++++++++
 xarray/tests/test_conventions.py | 20 ++++++++++
 4 files changed, 123 insertions(+)

diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py
index 3b7be898ccf..c95b32d644e 100644
--- a/xarray/coding/variables.py
+++ b/xarray/coding/variables.py
@@ -696,3 +696,66 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
 
     def decode(self, variable: Variable, name: T_Name = None) -> Variable:
         raise NotImplementedError()
+
+
+class IntervalCoder(VariableCoder):
+    """
+    Xarray-specific Interval Coder to roundtrip 1D pd.IntervalArray objects.
+
+    Encoding stacks the left and right interval edges along a new leading
+    ``__xarray_bounds__`` dimension and records the ``closed`` side, a
+    ``dtype`` marker, and the bounds dimension name in ``attrs``. Decoding
+    reverses this. Variables that are not intervals pass through unchanged.
+    """
+
+    encoded_dtype = "pandas_interval"
+    encoded_bounds_dim = "__xarray_bounds__"
+
+    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+        """Encode an interval-dtype variable as a stacked (bounds, dim) array."""
+        if isinstance(dtype := variable.dtype, pd.IntervalDtype):
+            dims, data, attrs, encoding = unpack_for_encoding(variable)
+
+            new_data = np.stack([data.left, data.right], axis=0)
+            dims = (self.encoded_bounds_dim, *dims)
+            safe_setitem(attrs, "closed", dtype.closed, name=name)
+            safe_setitem(attrs, "dtype", self.encoded_dtype, name=name)
+            safe_setitem(attrs, "bounds_dim", self.encoded_bounds_dim, name=name)
+            return Variable(dims, new_data, attrs, encoding, fastpath=True)
+        else:
+            # Not an interval: hand back the input instance, not the class.
+            return variable
+
+    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+        """Rebuild a pd.IntervalArray variable from its encoded form.
+
+        Raises ValueError if the marked variable is not two-dimensional.
+        """
+        if (
+            variable.attrs.get("dtype", None) == self.encoded_dtype
+            and self.encoded_bounds_dim in variable.dims
+        ):
+            if variable.ndim != 2:
+                raise ValueError(
+                    f"Cannot decode intervals for variable named {name!r} with more than two dimensions."
+                )
+
+            dims, data, attrs, encoding = unpack_for_decoding(variable)
+            pop_to(attrs, encoding, "dtype", name=name)
+            pop_to(attrs, encoding, "bounds_dim", name=name)
+            closed = pop_to(attrs, encoding, "closed", name=name)
+
+            # Select by name so the bounds dimension may sit in either position.
+            new_dims = tuple(d for d in variable.dims if d != self.encoded_bounds_dim)
+            variable = variable.load()
+            new_data = pd.arrays.IntervalArray.from_arrays(
+                variable.isel({self.encoded_bounds_dim: 0}).data,
+                variable.isel({self.encoded_bounds_dim: 1}).data,
+                closed=closed,
+            )
+            return Variable(
+                dims=new_dims, data=new_data, attrs=attrs, encoding=encoding
+            )
+        else:
+            # Nothing to decode: hand back the input instance, not the class.
+            return variable
diff --git a/xarray/conventions.py b/xarray/conventions.py
index 5ae40ea57d8..babe4712ebe 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -90,6 +90,9 @@ def encode_cf_variable(
     ensure_not_multiindex(var, name=name)
 
     for coder in [
+        # IntervalCoder must be before CFDatetimeCoder,
+        # so we can first encode the interval, then datetimes if necessary
+        variables.IntervalCoder(),
         CFDatetimeCoder(),
         CFTimedeltaCoder(),
         variables.CFScaleOffsetCoder(),
@@ -238,6 +241,8 @@ def decode_cf_variable(
                 )
         var = decode_times.decode(var, name=name)
 
+    var = variables.IntervalCoder().decode(var, name=name)
+
     if decode_endianness and not var.dtype.isnative:
         var = variables.EndianCoder().decode(var)
         original_dtype = var.dtype
diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py
index acb32504948..3260f9570e9 100644
--- a/xarray/tests/test_coding.py
+++ b/xarray/tests/test_coding.py
@@ -147,3 +147,38 @@ def test_decode_signed_from_unsigned(bits) -> None:
     decoded = coder.decode(encoded)
     assert decoded.dtype == signed_dtype
     assert decoded.values == original_values
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        [1, 2, 3, 4],
+        np.array([1, 2, 3, 4], dtype=float),
+        pd.date_range("2001-01-01", "2002-01-01", freq="MS"),
+    ],
+)
+@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
+def test_roundtrip_pandas_interval(data, closed) -> None:
+    v = xr.Variable("time", pd.IntervalIndex.from_breaks(data, closed=closed))
+    coder = variables.IntervalCoder()
+    encoded = coder.encode(v)
+    expected = xr.Variable(
+        dims=("__xarray_bounds__", "time"),
+        data=np.stack([data[:-1], data[1:]], axis=0),
+        attrs={
+            "dtype": "pandas_interval",
+            "bounds_dim": "__xarray_bounds__",
+            "closed": closed,
+        },
+    )
+    assert_identical(encoded, expected)
+
+    decoded = coder.decode(encoded)
+    assert_identical(decoded, v)
+
+
+def test_interval_coder_passes_through_non_intervals() -> None:
+    v = xr.Variable("x", [1, 2, 3])
+    coder = variables.IntervalCoder()
+    assert_identical(coder.encode(v), v)
+    assert_identical(coder.decode(v), v)
diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py
index ce792c83740..34c06a23235 100644
--- a/xarray/tests/test_conventions.py
+++ b/xarray/tests/test_conventions.py
@@ -666,3 +666,23 @@ def test_decode_cf_variables_decode_timedelta_warning() -> None:
 
     with pytest.warns(FutureWarning, match="decode_timedelta"):
         conventions.decode_cf_variables(variables, {})
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        [1, 2, 3, 4],
+        np.array([1, 2, 3, 4], dtype=float),
+        pd.date_range("2001-01-01", "2002-01-01", freq="MS"),
+    ],
+)
+@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
+def test_roundtrip_pandas_interval(data, closed) -> None:
+    v = Variable("time", pd.IntervalIndex.from_breaks(data, closed=closed))
+    encoded = conventions.encode_cf_variable(v)
+    if isinstance(data, pd.DatetimeIndex):
+        # make sure we've encoded datetimes.
+        assert "units" in encoded.attrs
+        assert "calendar" in encoded.attrs
+    roundtripped = conventions.decode_cf_variable("foo", encoded)
+    assert_identical(roundtripped, v)