Skip to content

Commit 087ebbb

Browse files
authored
ds.to_dict with data as arrays, not lists (#7739)
* first stab at ds.to_dict giving data as numpy objects * update whats-new.rst * mypy flailing: add dict typing to test_dataarray * mypy flailing 2: add dict typing to core/varaible.py * testing equality of encodings on Ds.from_dict(ds.to_dict()) roundtrips * .values -> .to_numpy() * requested changes on 4/19/23 * to_dict kwarg data handles bool and str, "list" and True return list of Python datatypes, "array" returns numpy.ndarrays, False returns only the schema * fix mypy hashable not being string * touch ups on to_dict() changes * to_dict with dask, tested. other minor things * touch up * finalize to_dict()
1 parent a220022 commit 087ebbb

File tree

6 files changed

+113
-36
lines changed

6 files changed

+113
-36
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ New Features
105105
- Added ability to save ``DataArray`` objects directly to Zarr using :py:meth:`~xarray.DataArray.to_zarr`.
106106
(:issue:`7692`, :pull:`7693`) .
107107
By `Joe Hamman <https://github.com/jhamman>`_.
108+
- Keyword argument `data='array'` to both :py:meth:`xarray.Dataset.to_dict` and
109+
:py:meth:`xarray.DataArray.to_dict` will now return data as the underlying array type. Python lists are returned for `data='list'` or `data=True`. Supplying `data=False` only returns the schema without data. ``encoding=True`` returns the encoding dictionary for the underlying variable also.
110+
(:issue:`1599`, :pull:`7739`) .
111+
By `James McCreight <https://github.com/jmccreight>`_.
108112

109113
Breaking changes
110114
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4174,7 +4174,9 @@ def to_zarr(
41744174
zarr_version=zarr_version,
41754175
)
41764176

4177-
def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]:
4177+
def to_dict(
4178+
self, data: bool | Literal["list", "array"] = "list", encoding: bool = False
4179+
) -> dict[str, Any]:
41784180
"""
41794181
Convert this xarray.DataArray into a dictionary following xarray
41804182
naming conventions.
@@ -4185,9 +4187,14 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]:
41854187
41864188
Parameters
41874189
----------
4188-
data : bool, default: True
4190+
data : bool or {"list", "array"}, default: "list"
41894191
Whether to include the actual data in the dictionary. When set to
4190-
False, returns just the schema.
4192+
False, returns just the schema. If set to "array", returns data as
4193+
underlying array type. If set to "list" (or True for backwards
4194+
compatibility), returns data in lists of Python data types. Note
4195+
that for obtaining the "list" output efficiently, use
4196+
`da.compute().to_dict(data="list")`.
4197+
41914198
encoding : bool, default: False
41924199
Whether to include the Dataset's encoding in the dictionary.
41934200

xarray/core/dataset.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6441,7 +6441,9 @@ def to_dask_dataframe(
64416441

64426442
return df
64436443

6444-
def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]:
6444+
def to_dict(
6445+
self, data: bool | Literal["list", "array"] = "list", encoding: bool = False
6446+
) -> dict[str, Any]:
64456447
"""
64466448
Convert this dataset to a dictionary following xarray naming
64476449
conventions.
@@ -6452,9 +6454,14 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict[str, Any]:
64526454
64536455
Parameters
64546456
----------
6455-
data : bool, default: True
6457+
data : bool or {"list", "array"}, default: "list"
64566458
Whether to include the actual data in the dictionary. When set to
6457-
False, returns just the schema.
6459+
False, returns just the schema. If set to "array", returns data as
6460+
underlying array type. If set to "list" (or True for backwards
6461+
compatibility), returns data in lists of Python data types. Note
6462+
that for obtaining the "list" output efficiently, use
6463+
`ds.compute().to_dict(data="list")`.
6464+
64586465
encoding : bool, default: False
64596466
Whether to include the Dataset's encoding in the dictionary.
64606467
@@ -6560,7 +6567,8 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset:
65606567
)
65616568
try:
65626569
variable_dict = {
6563-
k: (v["dims"], v["data"], v.get("attrs")) for k, v in variables
6570+
k: (v["dims"], v["data"], v.get("attrs"), v.get("encoding"))
6571+
for k, v in variables
65646572
}
65656573
except KeyError as e:
65666574
raise ValueError(

xarray/core/variable.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -633,11 +633,23 @@ def to_index(self) -> pd.Index:
633633
"""Convert this variable to a pandas.Index"""
634634
return self.to_index_variable().to_index()
635635

636-
def to_dict(self, data: bool = True, encoding: bool = False) -> dict:
636+
def to_dict(
637+
self, data: bool | str = "list", encoding: bool = False
638+
) -> dict[str, Any]:
637639
"""Dictionary representation of variable."""
638-
item = {"dims": self.dims, "attrs": decode_numpy_dict_values(self.attrs)}
639-
if data:
640-
item["data"] = ensure_us_time_resolution(self.values).tolist()
640+
item: dict[str, Any] = {
641+
"dims": self.dims,
642+
"attrs": decode_numpy_dict_values(self.attrs),
643+
}
644+
if data is not False:
645+
if data in [True, "list"]:
646+
item["data"] = ensure_us_time_resolution(self.to_numpy()).tolist()
647+
elif data == "array":
648+
item["data"] = ensure_us_time_resolution(self.data)
649+
else:
650+
msg = 'data argument must be bool, "list", or "array"'
651+
raise ValueError(msg)
652+
641653
else:
642654
item.update({"dtype": str(self.dtype), "shape": self.shape})
643655

xarray/tests/test_dataarray.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Hashable
77
from copy import deepcopy
88
from textwrap import dedent
9-
from typing import Any, Final, cast
9+
from typing import Any, Final, Literal, cast
1010

1111
import numpy as np
1212
import pandas as pd
@@ -3345,46 +3345,70 @@ def test_series_categorical_index(self) -> None:
33453345
arr = DataArray(s)
33463346
assert "'a'" in repr(arr) # should not error
33473347

3348+
@pytest.mark.parametrize("use_dask", [True, False])
3349+
@pytest.mark.parametrize("data", ["list", "array", True])
33483350
@pytest.mark.parametrize("encoding", [True, False])
3349-
def test_to_and_from_dict(self, encoding) -> None:
3351+
def test_to_and_from_dict(
3352+
self, encoding: bool, data: bool | Literal["list", "array"], use_dask: bool
3353+
) -> None:
3354+
if use_dask and not has_dask:
3355+
pytest.skip("requires dask")
3356+
encoding_data = {"bar": "spam"}
33503357
array = DataArray(
33513358
np.random.randn(2, 3), {"x": ["a", "b"]}, ["x", "y"], name="foo"
33523359
)
3353-
array.encoding = {"bar": "spam"}
3354-
expected = {
3360+
array.encoding = encoding_data
3361+
3362+
return_data = array.to_numpy()
3363+
coords_data = np.array(["a", "b"])
3364+
if data == "list" or data is True:
3365+
return_data = return_data.tolist()
3366+
coords_data = coords_data.tolist()
3367+
3368+
expected: dict[str, Any] = {
33553369
"name": "foo",
33563370
"dims": ("x", "y"),
3357-
"data": array.values.tolist(),
3371+
"data": return_data,
33583372
"attrs": {},
3359-
"coords": {"x": {"dims": ("x",), "data": ["a", "b"], "attrs": {}}},
3373+
"coords": {"x": {"dims": ("x",), "data": coords_data, "attrs": {}}},
33603374
}
33613375
if encoding:
3362-
expected["encoding"] = {"bar": "spam"}
3363-
actual = array.to_dict(encoding=encoding)
3376+
expected["encoding"] = encoding_data
3377+
3378+
if has_dask:
3379+
da = array.chunk()
3380+
else:
3381+
da = array
3382+
3383+
if data == "array" or data is False:
3384+
with raise_if_dask_computes():
3385+
actual = da.to_dict(encoding=encoding, data=data)
3386+
else:
3387+
actual = da.to_dict(encoding=encoding, data=data)
33643388

33653389
# check that they are identical
3366-
assert expected == actual
3390+
np.testing.assert_equal(expected, actual)
33673391

33683392
# check roundtrip
3369-
assert_identical(array, DataArray.from_dict(actual))
3393+
assert_identical(da, DataArray.from_dict(actual))
33703394

33713395
# a more bare bones representation still roundtrips
33723396
d = {
33733397
"name": "foo",
33743398
"dims": ("x", "y"),
3375-
"data": array.values.tolist(),
3399+
"data": da.values.tolist(),
33763400
"coords": {"x": {"dims": "x", "data": ["a", "b"]}},
33773401
}
3378-
assert_identical(array, DataArray.from_dict(d))
3402+
assert_identical(da, DataArray.from_dict(d))
33793403

33803404
# and the most bare bones representation still roundtrips
3381-
d = {"name": "foo", "dims": ("x", "y"), "data": array.values}
3382-
assert_identical(array.drop_vars("x"), DataArray.from_dict(d))
3405+
d = {"name": "foo", "dims": ("x", "y"), "data": da.values}
3406+
assert_identical(da.drop_vars("x"), DataArray.from_dict(d))
33833407

33843408
# missing a dims in the coords
33853409
d = {
33863410
"dims": ("x", "y"),
3387-
"data": array.values,
3411+
"data": da.values,
33883412
"coords": {"x": {"data": ["a", "b"]}},
33893413
}
33903414
with pytest.raises(
@@ -3407,7 +3431,7 @@ def test_to_and_from_dict(self, encoding) -> None:
34073431
endiantype = "<U1" if sys.byteorder == "little" else ">U1"
34083432
expected_no_data["coords"]["x"].update({"dtype": endiantype, "shape": (2,)})
34093433
expected_no_data.update({"dtype": "float64", "shape": (2, 3)})
3410-
actual_no_data = array.to_dict(data=False, encoding=encoding)
3434+
actual_no_data = da.to_dict(data=False, encoding=encoding)
34113435
assert expected_no_data == actual_no_data
34123436

34133437
def test_to_and_from_dict_with_time_dim(self) -> None:

xarray/tests/test_dataset.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from copy import copy, deepcopy
99
from io import StringIO
1010
from textwrap import dedent
11-
from typing import Any
11+
from typing import Any, Literal
1212

1313
import numpy as np
1414
import pandas as pd
@@ -4596,7 +4596,11 @@ def test_convert_dataframe_with_many_types_and_multiindex(self) -> None:
45964596
expected = df.apply(np.asarray)
45974597
assert roundtripped.equals(expected)
45984598

4599-
def test_to_and_from_dict(self) -> None:
4599+
@pytest.mark.parametrize("encoding", [True, False])
4600+
@pytest.mark.parametrize("data", [True, "list", "array"])
4601+
def test_to_and_from_dict(
4602+
self, encoding: bool, data: bool | Literal["list", "array"]
4603+
) -> None:
46004604
# <xarray.Dataset>
46014605
# Dimensions: (t: 10)
46024606
# Coordinates:
@@ -4617,14 +4621,25 @@ def test_to_and_from_dict(self) -> None:
46174621
"b": {"dims": ("t",), "data": y.tolist(), "attrs": {}},
46184622
},
46194623
}
4624+
if encoding:
4625+
ds.t.encoding.update({"foo": "bar"})
4626+
expected["encoding"] = {}
4627+
expected["coords"]["t"]["encoding"] = ds.t.encoding
4628+
for vvs in ["a", "b"]:
4629+
expected["data_vars"][vvs]["encoding"] = {}
46204630

4621-
actual = ds.to_dict()
4631+
actual = ds.to_dict(data=data, encoding=encoding)
46224632

46234633
# check that they are identical
4624-
assert expected == actual
4634+
np.testing.assert_equal(expected, actual)
46254635

46264636
# check roundtrip
4627-
assert_identical(ds, Dataset.from_dict(actual))
4637+
ds_rt = Dataset.from_dict(actual)
4638+
assert_identical(ds, ds_rt)
4639+
if encoding:
4640+
assert set(ds_rt.variables) == set(ds.variables)
4641+
for vv in ds.variables:
4642+
np.testing.assert_equal(ds_rt[vv].encoding, ds[vv].encoding)
46284643

46294644
# check the data=False option
46304645
expected_no_data = expected.copy()
@@ -4635,14 +4650,18 @@ def test_to_and_from_dict(self) -> None:
46354650
expected_no_data["coords"]["t"].update({"dtype": endiantype, "shape": (10,)})
46364651
expected_no_data["data_vars"]["a"].update({"dtype": "float64", "shape": (10,)})
46374652
expected_no_data["data_vars"]["b"].update({"dtype": "float64", "shape": (10,)})
4638-
actual_no_data = ds.to_dict(data=False)
4653+
actual_no_data = ds.to_dict(data=False, encoding=encoding)
46394654
assert expected_no_data == actual_no_data
46404655

46414656
# verify coords are included roundtrip
46424657
expected_ds = ds.set_coords("b")
4643-
actual2 = Dataset.from_dict(expected_ds.to_dict())
4658+
actual2 = Dataset.from_dict(expected_ds.to_dict(data=data, encoding=encoding))
46444659

46454660
assert_identical(expected_ds, actual2)
4661+
if encoding:
4662+
assert set(expected_ds.variables) == set(actual2.variables)
4663+
for vv in ds.variables:
4664+
np.testing.assert_equal(expected_ds[vv].encoding, actual2[vv].encoding)
46464665

46474666
# test some incomplete dicts:
46484667
# this one has no attrs field, the dims are strings, and x, y are
@@ -4690,7 +4709,10 @@ def test_to_and_from_dict_with_time_dim(self) -> None:
46904709
roundtripped = Dataset.from_dict(ds.to_dict())
46914710
assert_identical(ds, roundtripped)
46924711

4693-
def test_to_and_from_dict_with_nan_nat(self) -> None:
4712+
@pytest.mark.parametrize("data", [True, "list", "array"])
4713+
def test_to_and_from_dict_with_nan_nat(
4714+
self, data: bool | Literal["list", "array"]
4715+
) -> None:
46944716
x = np.random.randn(10, 3)
46954717
y = np.random.randn(10, 3)
46964718
y[2] = np.nan
@@ -4706,7 +4728,7 @@ def test_to_and_from_dict_with_nan_nat(self) -> None:
47064728
"lat": ("lat", lat),
47074729
}
47084730
)
4709-
roundtripped = Dataset.from_dict(ds.to_dict())
4731+
roundtripped = Dataset.from_dict(ds.to_dict(data=data))
47104732
assert_identical(ds, roundtripped)
47114733

47124734
def test_to_dict_with_numpy_attrs(self) -> None:

0 commit comments

Comments
 (0)