Commit 3679a5d

Allow setting (or skipping) new indexes in open_dataset (#8051)

Authored by benbovy, with dcherian, TomNicholas, and keewis.

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: Tom Nicholas <tom@cworthy.org>
Co-authored-by: Justus Magin <keewis@posteo.de>
Co-authored-by: Justus Magin <keewis@users.noreply.github.com>
Co-authored-by: Deepak Cherian <deepak@cherian.net>

1 parent 37dbae1, commit 3679a5d

File tree: 6 files changed, +185 −10 lines

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion

@@ -12,7 +12,8 @@ v2025.07.1 (unreleased)
 
 New Features
 ~~~~~~~~~~~~
-
+- Allow skipping the creation of default indexes when opening datasets (:pull:`8051`).
+  By `Benoit Bovy <https://github.com/benbovy>`_ and `Justus Magin <https://github.com/keewis>`_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
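
In practice, the new keyword is used like this (a minimal sketch; the file name data.nc is hypothetical):

    import xarray as xr

    # Default behaviour: pandas indexes are created for dimension coordinates,
    # which loads their data into memory.
    ds = xr.open_dataset("data.nc")

    # Skip default index creation: coordinate data stays lazy, and
    # ds_noindex.xindexes is empty unless the backend itself created indexes.
    ds_noindex = xr.open_dataset("data.nc", create_default_indexes=False)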

xarray/backends/api.py

Lines changed: 62 additions & 7 deletions

@@ -36,6 +36,7 @@
 from xarray.backends.locks import _get_scheduler
 from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
 from xarray.core import indexing
+from xarray.core.coordinates import Coordinates
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset
 from xarray.core.datatree import DataTree
@@ -379,6 +380,15 @@ def _chunk_ds(
     return backend_ds._replace(variables)
 
 
+def _maybe_create_default_indexes(ds):
+    to_index = {
+        name: coord.variable
+        for name, coord in ds.coords.items()
+        if coord.dims == (name,) and name not in ds.xindexes
+    }
+    return ds.assign_coords(Coordinates(to_index))
+
+
 def _dataset_from_backend_dataset(
     backend_ds,
     filename_or_obj,
@@ -389,6 +399,7 @@ def _dataset_from_backend_dataset(
     inline_array,
     chunked_array_type,
     from_array_kwargs,
+    create_default_indexes,
     **extra_tokens,
 ):
     if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
@@ -397,11 +408,15 @@ def _dataset_from_backend_dataset(
         )
 
     _protect_dataset_variables_inplace(backend_ds, cache)
-    if chunks is None:
-        ds = backend_ds
+
+    if create_default_indexes:
+        ds = _maybe_create_default_indexes(backend_ds)
     else:
+        ds = backend_ds
+
+    if chunks is not None:
         ds = _chunk_ds(
-            backend_ds,
+            ds,
             filename_or_obj,
             engine,
             chunks,
@@ -434,6 +449,7 @@ def _datatree_from_backend_datatree(
     inline_array,
     chunked_array_type,
     from_array_kwargs,
+    create_default_indexes,
     **extra_tokens,
 ):
     if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
@@ -442,9 +458,11 @@ def _datatree_from_backend_datatree(
         )
 
     _protect_datatree_variables_inplace(backend_tree, cache)
-    if chunks is None:
-        tree = backend_tree
+    if create_default_indexes:
+        tree = backend_tree.map_over_datasets(_maybe_create_default_indexes)
     else:
+        tree = backend_tree
+    if chunks is not None:
         tree = DataTree.from_dict(
             {
                 path: _chunk_ds(
@@ -459,11 +477,12 @@ def _datatree_from_backend_datatree(
                     node=path,
                     **extra_tokens,
                 )
-                for path, [node] in group_subtrees(backend_tree)
+                for path, [node] in group_subtrees(tree)
             },
-            name=backend_tree.name,
+            name=tree.name,
         )
 
+    if create_default_indexes or chunks is not None:
         for path, [node] in group_subtrees(backend_tree):
             tree[path].set_close(node._close)
 
@@ -497,6 +516,7 @@ def open_dataset(
     concat_characters: bool | Mapping[str, bool] | None = None,
     decode_coords: Literal["coordinates", "all"] | bool | None = None,
     drop_variables: str | Iterable[str] | None = None,
+    create_default_indexes: bool = True,
     inline_array: bool = False,
     chunked_array_type: str | None = None,
     from_array_kwargs: dict[str, Any] | None = None,
@@ -610,6 +630,13 @@
         A variable or list of variables to exclude from being parsed from the
         dataset. This may be useful to drop variables with problems or
         inconsistent values.
+    create_default_indexes : bool, default: True
+        If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
+        which loads the coordinate data into memory. Set it to False if you want to avoid loading
+        data into memory.
+
+        Note that backends can still choose to create other indexes. If you want to control that,
+        please refer to the backend's documentation.
     inline_array: bool, default: False
         How to include the array in the dask task graph.
         By default(``inline_array=False``) the array is included in a task by
@@ -702,6 +729,7 @@
         chunked_array_type,
         from_array_kwargs,
         drop_variables=drop_variables,
+        create_default_indexes=create_default_indexes,
         **decoders,
         **kwargs,
     )
@@ -725,6 +753,7 @@ def open_dataarray(
     concat_characters: bool | None = None,
     decode_coords: Literal["coordinates", "all"] | bool | None = None,
     drop_variables: str | Iterable[str] | None = None,
+    create_default_indexes: bool = True,
     inline_array: bool = False,
     chunked_array_type: str | None = None,
     from_array_kwargs: dict[str, Any] | None = None,
@@ -833,6 +862,13 @@
         A variable or list of variables to exclude from being parsed from the
         dataset. This may be useful to drop variables with problems or
         inconsistent values.
+    create_default_indexes : bool, default: True
+        If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
+        which loads the coordinate data into memory. Set it to False if you want to avoid loading
+        data into memory.
+
+        Note that backends can still choose to create other indexes. If you want to control that,
+        please refer to the backend's documentation.
     inline_array: bool, default: False
         How to include the array in the dask task graph.
         By default(``inline_array=False``) the array is included in a task by
@@ -890,6 +926,7 @@
         chunks=chunks,
         cache=cache,
         drop_variables=drop_variables,
+        create_default_indexes=create_default_indexes,
         inline_array=inline_array,
         chunked_array_type=chunked_array_type,
         from_array_kwargs=from_array_kwargs,
@@ -946,6 +983,7 @@ def open_datatree(
     concat_characters: bool | Mapping[str, bool] | None = None,
     decode_coords: Literal["coordinates", "all"] | bool | None = None,
     drop_variables: str | Iterable[str] | None = None,
+    create_default_indexes: bool = True,
     inline_array: bool = False,
     chunked_array_type: str | None = None,
     from_array_kwargs: dict[str, Any] | None = None,
@@ -1055,6 +1093,13 @@
         A variable or list of variables to exclude from being parsed from the
         dataset. This may be useful to drop variables with problems or
         inconsistent values.
+    create_default_indexes : bool, default: True
+        If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
+        which loads the coordinate data into memory. Set it to False if you want to avoid loading
+        data into memory.
+
+        Note that backends can still choose to create other indexes. If you want to control that,
+        please refer to the backend's documentation.
     inline_array: bool, default: False
         How to include the array in the dask task graph.
         By default(``inline_array=False``) the array is included in a task by
@@ -1148,6 +1193,7 @@
         chunked_array_type,
         from_array_kwargs,
         drop_variables=drop_variables,
+        create_default_indexes=create_default_indexes,
         **decoders,
         **kwargs,
     )
@@ -1175,6 +1221,7 @@ def open_groups(
     concat_characters: bool | Mapping[str, bool] | None = None,
     decode_coords: Literal["coordinates", "all"] | bool | None = None,
     drop_variables: str | Iterable[str] | None = None,
+    create_default_indexes: bool = True,
     inline_array: bool = False,
     chunked_array_type: str | None = None,
     from_array_kwargs: dict[str, Any] | None = None,
@@ -1286,6 +1333,13 @@
         A variable or list of variables to exclude from being parsed from the
         dataset. This may be useful to drop variables with problems or
         inconsistent values.
+    create_default_indexes : bool, default: True
+        If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
+        which loads the coordinate data into memory. Set it to False if you want to avoid loading
+        data into memory.
+
+        Note that backends can still choose to create other indexes. If you want to control that,
+        please refer to the backend's documentation.
     inline_array: bool, default: False
         How to include the array in the dask task graph.
         By default(``inline_array=False``) the array is included in a task by
@@ -1381,6 +1435,7 @@
         chunked_array_type,
         from_array_kwargs,
         drop_variables=drop_variables,
+        create_default_indexes=create_default_indexes,
         **decoders,
         **kwargs,
     )
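
The _maybe_create_default_indexes helper relies on the fact that constructing a Coordinates object without an explicit indexes argument builds default pandas indexes for dimension coordinates. A self-contained sketch of that mechanism (the toy dataset is illustrative, not part of the commit):

    import xarray as xr
    from xarray import Coordinates, Variable

    # A dataset whose "x" dimension coordinate deliberately has no index,
    # mimicking what a backend now returns when index creation is skipped.
    coords = Coordinates({"x": Variable("x", [-1, 0, 1])}, indexes={})
    ds = xr.Dataset({"data": ("x", [10, 20, 30])}, coords=coords)
    assert len(ds.xindexes) == 0

    # Same selection logic as the helper above: dimension coordinates
    # that do not have an index yet.
    to_index = {
        name: coord.variable
        for name, coord in ds.coords.items()
        if coord.dims == (name,) and name not in ds.xindexes
    }
    # Coordinates(...) without indexes= creates default pandas indexes.
    indexed = ds.assign_coords(Coordinates(to_index))
    assert "x" in indexed.xindexes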

xarray/backends/store.py

Lines changed: 15 additions & 2 deletions

@@ -9,6 +9,7 @@
     AbstractDataStore,
     BackendEntrypoint,
 )
+from xarray.core.coordinates import Coordinates
 from xarray.core.dataset import Dataset
 
 if TYPE_CHECKING:
@@ -36,6 +37,7 @@ def open_dataset(
         concat_characters=True,
         decode_coords=True,
         drop_variables: str | Iterable[str] | None = None,
+        set_indexes: bool = True,
         use_cftime=None,
         decode_timedelta=None,
     ) -> Dataset:
@@ -56,8 +58,19 @@
             decode_timedelta=decode_timedelta,
         )
 
-        ds = Dataset(vars, attrs=attrs)
-        ds = ds.set_coords(coord_names.intersection(vars))
+        # split data and coordinate variables (promote dimension coordinates)
+        data_vars = {}
+        coord_vars = {}
+        for name, var in vars.items():
+            if name in coord_names or var.dims == (name,):
+                coord_vars[name] = var
+            else:
+                data_vars[name] = var
+
+        # explicit Coordinates object with no index passed
+        coords = Coordinates(coord_vars, indexes={})
+
+        ds = Dataset(data_vars, coords=coords, attrs=attrs)
         ds.set_close(filename_or_obj.close)
         ds.encoding = encoding
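
The key detail above is Coordinates(coord_vars, indexes={}): an explicitly empty indexes mapping promotes the variables to coordinates while suppressing the automatic index creation that the old Dataset + set_coords path triggered. A small standalone illustration of the same split (the variables here are invented):

    import numpy as np
    from xarray import Coordinates, Dataset, Variable

    variables = {
        "data": Variable(("x",), np.arange(3)),
        "x": Variable(("x",), [-1, 0, 1]),  # dimension coordinate
    }
    coord_names: set[str] = set()  # names flagged as coordinates by decoding

    # Split data and coordinate variables (promote dimension coordinates).
    data_vars, coord_vars = {}, {}
    for name, var in variables.items():
        if name in coord_names or var.dims == (name,):
            coord_vars[name] = var
        else:
            data_vars[name] = var

    # indexes={} means: make these coordinates, but build no index for them.
    ds = Dataset(data_vars, coords=Coordinates(coord_vars, indexes={}))
    assert list(ds.coords) == ["x"] and len(ds.xindexes) == 0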

xarray/backends/zarr.py

Lines changed: 9 additions & 0 deletions

@@ -1347,6 +1347,7 @@ def open_zarr(
     use_zarr_fill_value_as_mask=None,
     chunked_array_type: str | None = None,
     from_array_kwargs: dict[str, Any] | None = None,
+    create_default_indexes=True,
     **kwargs,
 ):
     """Load and decode a dataset from a Zarr store.
@@ -1457,6 +1458,13 @@
         chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg.
         Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to
         :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
+    create_default_indexes : bool, default: True
+        If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
+        which loads the coordinate data into memory. Set it to False if you want to avoid loading
+        data into memory.
+
+        Note that backends can still choose to create other indexes. If you want to control that,
+        please refer to the backend's documentation.
 
     Returns
     -------
@@ -1513,6 +1521,7 @@
         engine="zarr",
         chunks=chunks,
         drop_variables=drop_variables,
+        create_default_indexes=create_default_indexes,
         chunked_array_type=chunked_array_type,
         from_array_kwargs=from_array_kwargs,
         backend_kwargs=backend_kwargs,
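
open_zarr simply forwards the new keyword to open_dataset with engine="zarr". Indexes that were skipped at open time can still be created later, on demand (a sketch; the store path is hypothetical):

    import xarray as xr

    # Open lazily, without building a pandas index for "x".
    ds = xr.open_zarr("store.zarr", create_default_indexes=False)

    # Opt in to an index only for the coordinate you need label-based
    # selection on; set_xindex defaults to a pandas index.
    ds = ds.set_xindex("x")
    subset = ds.sel(x=slice(0, 10))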

xarray/tests/test_backends.py

Lines changed: 61 additions & 0 deletions

@@ -55,6 +55,7 @@
 from xarray.coding.variables import SerializationWarning
 from xarray.conventions import encode_dataset_coordinates
 from xarray.core import indexing
+from xarray.core.indexes import PandasIndex
 from xarray.core.options import set_options
 from xarray.core.types import PDDatetimeUnitOptions
 from xarray.core.utils import module_available
@@ -2066,6 +2067,26 @@ def test_encoding_enum__error_multiple_variable_with_changing_enum(self):
         with self.roundtrip(original):
             pass
 
+    @pytest.mark.parametrize("create_default_indexes", [True, False])
+    def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None:
+        store_path = tmp_path / "tmp.nc"
+        original_ds = xr.Dataset(
+            {"data": ("x", np.arange(3))}, coords={"x": [-1, 0, 1]}
+        )
+        original_ds.to_netcdf(store_path, engine=self.engine, mode="w")
+
+        with open_dataset(
+            store_path,
+            engine=self.engine,
+            create_default_indexes=create_default_indexes,
+        ) as loaded_ds:
+            if create_default_indexes:
+                assert list(loaded_ds.xindexes) == ["x"] and isinstance(
+                    loaded_ds.xindexes["x"], PandasIndex
+                )
+            else:
+                assert len(loaded_ds.xindexes) == 0
+
 
 @requires_netCDF4
 class TestNetCDF4Data(NetCDF4Base):
@@ -4063,6 +4084,26 @@ def test_pickle(self) -> None:
     def test_pickle_dataarray(self) -> None:
         pass
 
+    @pytest.mark.parametrize("create_default_indexes", [True, False])
+    def test_create_default_indexes(self, tmp_path, create_default_indexes) -> None:
+        store_path = tmp_path / "tmp.nc"
+        original_ds = xr.Dataset(
+            {"data": ("x", np.arange(3))}, coords={"x": [-1, 0, 1]}
+        )
+        original_ds.to_netcdf(store_path, engine=self.engine, mode="w")
+
+        with open_dataset(
+            store_path,
+            engine=self.engine,
+            create_default_indexes=create_default_indexes,
+        ) as loaded_ds:
+            if create_default_indexes:
+                assert list(loaded_ds.xindexes) == ["x"] and isinstance(
+                    loaded_ds.xindexes["x"], PandasIndex
+                )
+            else:
+                assert len(loaded_ds.xindexes) == 0
+
 
 @requires_scipy
 class TestScipyFilePath(CFEncodedBase, NetCDF3Only):
@@ -6434,6 +6475,26 @@ def test_zarr_closing_internal_zip_store():
     assert_identical(original_da, loaded_da)
 
 
+@requires_zarr
+@pytest.mark.parametrize("create_default_indexes", [True, False])
+def test_zarr_create_default_indexes(tmp_path, create_default_indexes) -> None:
+    from xarray.core.indexes import PandasIndex
+
+    store_path = tmp_path / "tmp.zarr"
+    original_ds = xr.Dataset({"data": ("x", np.arange(3))}, coords={"x": [-1, 0, 1]})
+    original_ds.to_zarr(store_path, mode="w")
+
+    with open_dataset(
+        store_path, engine="zarr", create_default_indexes=create_default_indexes
+    ) as loaded_ds:
+        if create_default_indexes:
+            assert list(loaded_ds.xindexes) == ["x"] and isinstance(
+                loaded_ds.xindexes["x"], PandasIndex
+            )
+        else:
+            assert len(loaded_ds.xindexes) == 0
+
+
 @requires_zarr
 @pytest.mark.usefixtures("default_zarr_format")
 def test_raises_key_error_on_invalid_zarr_store(tmp_path):
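
The tests above cover the Dataset entry points; the datatree code path changed in api.py can be exercised the same way (a sketch, assuming a netCDF-capable backend such as netcdf4 or h5netcdf is installed; group names and file name are invented):

    import numpy as np
    import xarray as xr

    tree = xr.DataTree.from_dict(
        {
            "/": xr.Dataset({"a": ("x", np.arange(3))}, coords={"x": [0, 1, 2]}),
            "/child": xr.Dataset({"b": ("y", np.arange(2))}, coords={"y": [10, 20]}),
        }
    )
    tree.to_netcdf("tree.nc")

    # Every node comes back without indexes.
    dt = xr.open_datatree("tree.nc", create_default_indexes=False)
    assert all(len(node.dataset.xindexes) == 0 for node in dt.subtree)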
