From a7e096ec3865999bd800edbe340f5d4f5327641b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kai=20M=C3=BChlbauer?=
Date: Fri, 6 Jun 2025 16:00:34 +0200
Subject: [PATCH 1/4] WIP: use dataclasses for combining keyword arguments to
 reduce signature footprint of open_dataset

---
 xarray/backends/api.py       | 125 ++++++++++++++++++--------
 xarray/backends/common.py    |  74 +++++++++++++++-
 xarray/backends/h5netcdf_.py | 167 +++++++++++++++++++++++------------
 xarray/backends/netCDF4_.py  | 139 ++++++++++++++++++-----------
 xarray/backends/plugins.py   |   6 +-
 xarray/backends/store.py     |  24 ++---
 xarray/conventions.py        |   4 +-
 7 files changed, 369 insertions(+), 170 deletions(-)

diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 79deaed927d..1911912cc00 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -31,6 +31,7 @@
     ArrayWriter,
     _find_absolute_paths,
     _normalize_path,
+    _reset_dataclass_to_false,
 )
 from xarray.backends.locks import _get_scheduler
 from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
@@ -382,19 +383,22 @@ def _dataset_from_backend_dataset(
     backend_ds,
     filename_or_obj,
     engine,
-    chunks,
-    cache,
-    overwrite_encoded_chunks,
-    inline_array,
-    chunked_array_type,
-    from_array_kwargs,
+    coder_opts,
+    backend_opts,
     **extra_tokens,
 ):
+    backend_kwargs = asdict(backend_opts)
+    chunks = backend_kwargs.get("chunks")
+    cache = backend_kwargs.get("cache")
     if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
         raise ValueError(
             f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
         )

+    coders_kwargs = asdict(coder_opts)
+    extra_tokens.update(**coders_kwargs)
+    extra_tokens.update(**backend_kwargs)
+
     _protect_dataset_variables_inplace(backend_ds, cache)
     if chunks is None:
         ds = backend_ds
@@ -403,11 +407,6 @@ def _dataset_from_backend_dataset(
             backend_ds,
             filename_or_obj,
             engine,
-            chunks,
-            overwrite_encoded_chunks,
-            inline_array,
-            chunked_array_type,
-            from_array_kwargs,
             **extra_tokens,
         )

@@ -476,6 +475,23 @@ def _datatree_from_backend_datatree(
     return tree


+from dataclasses import asdict
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+
+from xarray.backends.common import BackendOptions, CoderOptions, XarrayBackendOptions
+
+# @dataclass(frozen=True)
+# class XarrayBackendOptions:
+#     chunks: Optional[T_Chunks] = None
+#     cache: Optional[bool] = None
+#     inline_array: Optional[bool] = False
+#     chunked_array_type: Optional[str] = None
+#     from_array_kwargs: Optional[dict[str, Any]] = None
+#     overwrite_encoded_chunks: Optional[bool] = False
+
+
 def open_dataset(
     filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
     *,
@@ -500,6 +516,10 @@ def open_dataset(
     chunked_array_type: str | None = None,
     from_array_kwargs: dict[str, Any] | None = None,
     backend_kwargs: dict[str, Any] | None = None,
+    coder_opts: Union[bool, CoderOptions, None] = None,
+    open_opts: Union[bool, BackendOptions, None] = None,
+    backend_opts: Union[bool, BackendOptions, None] = None,
+    store_opts: Union[bool, BackendOptions, None] = None,
     **kwargs,
 ) -> Dataset:
     """Open and decode a dataset from a file or file-like object.
@@ -672,36 +692,69 @@ def open_dataset(

     backend = plugins.get_backend(engine)

-    decoders = _resolve_decoders_kwargs(
-        decode_cf,
-        open_backend_dataset_parameters=backend.open_dataset_parameters,
-        mask_and_scale=mask_and_scale,
-        decode_times=decode_times,
-        decode_timedelta=decode_timedelta,
-        concat_characters=concat_characters,
-        use_cftime=use_cftime,
-        decode_coords=decode_coords,
-    )
+    print("XX0:", backend)
+    print("XX1:", type(backend))
+    print("XX2:", type(backend.coder_opts))
+    print("XX3:", coder_opts)
+    print("XX4:", backend.coder_opts)
+
+    # initialize CoderOptions with decoders of not given
+    # Deprecation Fallback
+    if coder_opts is False:
+        coder_opts = _reset_dataclass_to_false(backend.coder_opts)
+    elif coder_opts is True:
+        coder_opts = backend.coder_opts
+    elif coder_opts is None:
+        decoders = _resolve_decoders_kwargs(
+            decode_cf,
+            open_backend_dataset_parameters=backend.open_dataset_parameters,
+            mask_and_scale=mask_and_scale,
+            decode_times=decode_times,
+            decode_timedelta=decode_timedelta,
+            concat_characters=concat_characters,
+            use_cftime=use_cftime,
+            decode_coords=decode_coords,
+        )
+        decoders["drop_variables"] = drop_variables
+        coder_opts = CoderOptions(**decoders)
+
+    if backend_opts is None:
+        backend_opts = XarrayBackendOptions(
+            chunks=chunks,
+            cache=cache,
+            inline_array=inline_array,
+            chunked_array_type=chunked_array_type,
+            from_array_kwargs=from_array_kwargs,
+            overwrite_encoded_chunks=kwargs.pop("overwrite_encoded_chunks", None),
+        )
+
+    # Check if store_opts have been ovrridden in the subclass
+    # That indicates new-style behaviour
+    # We can keep backwards compatibility
+    _store_opts = backend.store_opts
+    if type(_store_opts) is BackendOptions:
+        coder_kwargs = asdict(coder_opts)
+
+        backend_ds = backend.open_dataset(
+            filename_or_obj,
+            **coder_kwargs,
+            **kwargs,
+        )
+    else:
+        backend_ds = backend.open_dataset(
+            filename_or_obj,
+            coder_opts=coder_opts,
+            open_opts=open_opts,
+            store_opts=store_opts,
+            **kwargs,
+        )

-    overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
-    backend_ds = backend.open_dataset(
-        filename_or_obj,
-        drop_variables=drop_variables,
-        **decoders,
-        **kwargs,
-    )
     ds = _dataset_from_backend_dataset(
         backend_ds,
         filename_or_obj,
         engine,
-        chunks,
-        cache,
-        overwrite_encoded_chunks,
-        inline_array,
-        chunked_array_type,
-        from_array_kwargs,
-        drop_variables=drop_variables,
-        **decoders,
+        coder_opts,
+        backend_opts,
         **kwargs,
     )
     return ds

diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index e574f19e9d4..10a7509d963 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -6,17 +6,27 @@
 import traceback
 from collections.abc import Hashable, Iterable, Mapping, Sequence
 from glob import glob
-from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    ClassVar,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    overload,
+)

 import numpy as np
 import pandas as pd

 from xarray.coding import strings, variables
+from xarray.coding.times import CFDatetimeCoder, CFTimedeltaCoder
 from xarray.coding.variables import SerializationWarning
 from xarray.conventions import cf_encoder
 from xarray.core import indexing
 from xarray.core.datatree import DataTree, Variable
-from xarray.core.types import ReadBuffer
+from xarray.core.types import ReadBuffer, T_Chunks
 from xarray.core.utils import (
     FrozenDict,
     NdimSizeLenMixin,
@@ -646,6 +656,49 @@ def encode(self, variables, attributes):
         return variables, attributes


+from dataclasses import dataclass, fields, replace
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+
+def _reset_dataclass_to_false(instance):
+    field_names = [f.name for f in fields(instance)]
+    false_values = dict.fromkeys(field_names, False)
+    return replace(instance, **false_values)
+
+
+@dataclass(frozen=True)
+class BackendOptions:
+    pass
+
+
+@dataclass(frozen=True)
+class XarrayBackendOptions:
+    chunks: Optional[T_Chunks] = None
+    cache: Optional[bool] = None
+    inline_array: Optional[bool] = False
+    chunked_array_type: Optional[str] = None
+    from_array_kwargs: Optional[dict[str, Any]] = None
+    overwrite_encoded_chunks: Optional[bool] = False
+
+
+@dataclass(frozen=True)
+class CoderOptions:
+    # mask: Optional[bool] = None
+    # scale: Optional[bool] = None
+    mask_and_scale: Optional[bool | Mapping[str, bool]] = (None,)
+    decode_times: Optional[
+        bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder]
+    ] = None
+    decode_timedelta: Optional[
+        bool | CFTimedeltaCoder | Mapping[str, bool | CFTimedeltaCoder]
+    ] = None
+    use_cftime: Optional[bool | Mapping[str, bool]] = None
+    concat_characters: Optional[bool | Mapping[str, bool]] = None
+    decode_coords: Optional[Literal["coordinates", "all"] | bool] = None
+    drop_variables: Optional[str | Iterable[str]] = None
+
+
 class BackendEntrypoint:
     """
     ``BackendEntrypoint`` is a class container and it is the main interface
@@ -683,6 +736,19 @@ class BackendEntrypoint:
     open_dataset_parameters: ClassVar[tuple | None] = None
     description: ClassVar[str] = ""
     url: ClassVar[str] = ""
+    coder_class = CoderOptions
+    open_class = BackendOptions
+    store_class = BackendOptions
+
+    def __init__(
+        self,
+        coder_opts: Optional[CoderOptions] = None,
+        open_opts: Optional[CoderOptions] = None,
+        store_opts: Optional[CoderOptions] = None,
+    ):
+        self.coder_opts = coder_opts if coder_opts is not None else self.coder_class()
+        self.open_opts = open_opts if open_opts is not None else self.open_class()
+        self.store_opts = store_opts if store_opts is not None else self.store_class()

     def __repr__(self) -> str:
         txt = f"<{type(self).__name__}>"
@@ -696,6 +762,10 @@ def open_dataset(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
         *,
+        coder_opts: Union[bool, CoderOptions, None] = None,
+        backend_opts: Union[bool, BackendOptions, None] = None,
+        open_opts: Union[bool, BackendOptions, None] = None,
+        store_opts: Union[bool, BackendOptions, None] = None,
         drop_variables: str | Iterable[str] | None = None,
     ) -> Dataset:
         """
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index ba3a6d20e37..78a5a270b6b 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -4,7 +4,7 @@
 import io
 import os
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union

 import numpy as np

@@ -139,19 +139,27 @@ def open(
         cls,
         filename,
         mode="r",
-        format=None,
-        group=None,
-        lock=None,
-        autoclose=False,
-        invalid_netcdf=None,
-        phony_dims=None,
-        decode_vlen_strings=True,
-        driver=None,
-        driver_kwds=None,
-        storage_options: dict[str, Any] | None = None,
+        # format=None,
+        # group=None,
+        # lock=None,
+        # autoclose=False,
+        # invalid_netcdf=None,
+        # phony_dims=None,
+        # decode_vlen_strings=True,
+        # driver=None,
+        # driver_kwds=None,
+        # storage_options: dict[str, Any] | None = None,
+        open_opts=None,
+        store_opts=None,
+        **kwargs,
     ):
         import h5netcdf

+        open_kwargs = asdict(open_opts)
+        store_kwargs = asdict(store_opts)
+
+        driver = open_kwargs["driver"]
+        storage_options = open_kwargs.pop("storage_options", None)
         if isinstance(filename, str) and is_remote_uri(filename) and driver is None:
             mode_ = "rb" if mode == "r" else mode
             filename = _open_remote_file(
@@ -169,28 +177,33 @@ def open(
             raise ValueError(
                 f"{magic_number!r} is not the signature of a valid netCDF4 file"
             )
-
+        format = open_kwargs.pop("format")
         if format not in [None, "NETCDF4"]:
             raise ValueError("invalid format for h5netcdf backend")

-        kwargs = {
-            "invalid_netcdf": invalid_netcdf,
-            "decode_vlen_strings": decode_vlen_strings,
-            "driver": driver,
-        }
+        kwargs.update(open_kwargs)
+        # kwargs.update(kwargs.pop("driver_kwds", None))
+        #
+        # = {
+        #     "invalid_netcdf": invalid_netcdf,
+        #     "decode_vlen_strings": decode_vlen_strings,
+        #     "driver": driver,
+        # }
+        driver_kwds = kwargs.pop("driver_kwds")
         if driver_kwds is not None:
             kwargs.update(driver_kwds)
-        if phony_dims is not None:
-            kwargs["phony_dims"] = phony_dims
-
+        print("XX:", kwargs)
+        # if phony_dims is not None:
+        #     kwargs["phony_dims"] = phony_dims
+        lock = store_kwargs.get("lock", None)
         if lock is None:
             if mode == "r":
                 lock = HDF5_LOCK
             else:
                 lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])
-
+        store_kwargs["lock"] = lock
         manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs)
-        return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose)
+        return cls(manager, mode=mode, **store_kwargs)

     def _acquire(self, needs_lock=True):
         with self._manager.acquire_context(needs_lock) as root:
@@ -388,6 +401,35 @@ def _emit_phony_dims_warning():
     )


+from dataclasses import asdict, dataclass
+from typing import Optional
+
+from xarray.backends.locks import SerializableLock
+
+Buffer = Union[bytes, bytearray, memoryview]
+from xarray.backends.common import BackendOptions
+from xarray.backends.netCDF4_ import NetCDF4CoderOptions as H5netcdfCoderOptions
+
+
+@dataclass(frozen=True)
+class H5netcdfStoreOptions(BackendOptions):
+    group: Optional[str] = None
+    lock: Optional[SerializableLock] = None
+    autoclose: Optional[bool] = False
+
+
+@dataclass(frozen=True)
+class H5netcdfOpenOptions(BackendOptions):
+    format: Optional[str] = "NETCDF4"
+    driver: Optional[str] = None
+    driver_kwds: Optional[dict[str, Any]] = None
+    libver: Optional[Union[str, tuple[str]]] = None
+    invalid_netcdf: Optional[bool] = False
+    phony_dims: Optional[str] = "access"
+    decode_vlen_strings: Optional[bool] = True
+    storage_options: Optional[dict[str, Any]] = None
+
+
 class H5netcdfBackendEntrypoint(BackendEntrypoint):
     """
     Backend for netCDF files based on the h5netcdf package.
@@ -395,7 +437,7 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint):
     It can open ".nc", ".nc4", ".cdf" files but will only be
     selected as the default if the "netcdf4" engine is not available.

-    Additionally it can open valid HDF5 files, see
+    Additionally, it can open valid HDF5 files, see
     https://h5netcdf.org/#invalid-netcdf-files for more info.
     It will not be detected as valid backend for such files, so make
     sure to specify ``engine="h5netcdf"`` in ``open_dataset``.
@@ -410,6 +452,10 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint):
     backends.ScipyBackendEntrypoint
     """

+    coder_class = H5netcdfCoderOptions
+    open_class = H5netcdfOpenOptions
+    store_class = H5netcdfStoreOptions
+
     description = (
         "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray"
     )
@@ -433,59 +479,64 @@ def open_dataset(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
         *,
-        mask_and_scale=True,
-        decode_times=True,
-        concat_characters=True,
-        decode_coords=True,
-        drop_variables: str | Iterable[str] | None = None,
-        use_cftime=None,
-        decode_timedelta=None,
-        format=None,
-        group=None,
-        lock=None,
-        invalid_netcdf=None,
-        phony_dims=None,
-        decode_vlen_strings=True,
-        driver=None,
-        driver_kwds=None,
-        storage_options: dict[str, Any] | None = None,
+        # mask_and_scale=True,
+        # decode_times=True,
+        # concat_characters=True,
+        # decode_coords=True,
+        # drop_variables: str | Iterable[str] | None = None,
+        # use_cftime=None,
+        # decode_timedelta=None,
+        # format=None,
+        # group=None,
+        # lock=None,
+        # invalid_netcdf=None,
+        # phony_dims=None,
+        # decode_vlen_strings=True,
+        # driver=None,
+        # driver_kwds=None,
+        # storage_options: dict[str, Any] | None = None,
+        coder_opts: H5netcdfCoderOptions = None,
+        open_opts: H5netcdfOpenOptions = None,
+        store_opts: H5netcdfStoreOptions = None,
+        **kwargs,
     ) -> Dataset:
+        coder_opts = coder_opts if coder_opts is not None else self.coder_opts
+        open_opts = open_opts if open_opts is not None else self.open_opts
+        store_opts = store_opts if store_opts is not None else self.store_opts
+
         # Keep this message for some versions
         # remove and set phony_dims="access" above
-        emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims)
+        # emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims)

         filename_or_obj = _normalize_path(filename_or_obj)
         store = H5NetCDFStore.open(
             filename_or_obj,
-            format=format,
-            group=group,
-            lock=lock,
-            invalid_netcdf=invalid_netcdf,
-            phony_dims=phony_dims,
-            decode_vlen_strings=decode_vlen_strings,
-            driver=driver,
-            driver_kwds=driver_kwds,
-            storage_options=storage_options,
+            open_opts=open_opts,
+            store_opts=store_opts,
+            # format=format,
+            # group=group,
+            # lock=lock,
+            # invalid_netcdf=invalid_netcdf,
+            # phony_dims=phony_dims,
+            # decode_vlen_strings=decode_vlen_strings,
+            # driver=driver,
+            # driver_kwds=driver_kwds,
+            # storage_options=storage_options,
+            **kwargs,
         )

         store_entrypoint = StoreBackendEntrypoint()
         ds = store_entrypoint.open_dataset(
             store,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+            coder_opts=coder_opts,
         )

         # only warn if phony_dims exist in file
         # remove together with the above check
         # after some versions
-        if store.ds._root._phony_dim_count > 0 and emit_phony_dims_warning:
-            _emit_phony_dims_warning()
+        # if store.ds._root._phony_dim_count > 0 and emit_phony_dims_warning:
+        #     _emit_phony_dims_warning()

         return ds

diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index a23d247b6c3..6508fd389ba 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -401,18 +401,18 @@ def open(
         cls,
         filename,
         mode="r",
-        format="NETCDF4",
-        group=None,
-        clobber=True,
-        diskless=False,
-        persist=False,
-        auto_complex=None,
-        lock=None,
-        lock_maker=None,
-        autoclose=False,
+        # group=None,
+        # lock=None,
+        # autoclose=False,
+        open_opts=None,
+        store_opts=None,
+        **kwargs,
     ):
         import netCDF4

+        open_kwargs = asdict(open_opts)
+        store_kwargs = asdict(store_opts)
+
         if isinstance(filename, os.PathLike):
             filename = os.fspath(filename)
@@ -422,9 +422,11 @@ def open(
                 "with engine='scipy' or 'h5netcdf'"
             )

+        format = open_kwargs.pop("format")
         if format is None:
             format = "NETCDF4"

+        lock = store_kwargs.get("lock", None)
         if lock is None:
             if mode == "r":
                 if is_remote_uri(filename):
@@ -437,19 +439,13 @@ def open(
             else:
                 base_lock = NETCDFC_LOCK
             lock = combine_locks([base_lock, get_write_lock(filename)])
-
-        kwargs = dict(
-            clobber=clobber,
-            diskless=diskless,
-            persist=persist,
-            format=format,
-        )
-        if auto_complex is not None:
-            kwargs["auto_complex"] = auto_complex
+        store_kwargs["lock"] = lock
+        kwargs.update(open_kwargs)
+        print("DD:", kwargs)
         manager = CachingFileManager(
             netCDF4.Dataset, filename, mode=mode, kwargs=kwargs
         )
-        return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose)
+        return cls(manager, mode=mode, **store_kwargs)

     def _acquire(self, needs_lock=True):
         with self._manager.acquire_context(needs_lock) as root:
@@ -597,6 +593,52 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)


+from collections.abc import Mapping
+from dataclasses import asdict, dataclass
+from typing import Literal, Optional, Union
+
+from xarray.backends.locks import SerializableLock
+
+Buffer = Union[bytes, bytearray, memoryview]
+from xarray.backends.common import BackendOptions
+
+
+@dataclass(frozen=True)
+class NetCDF4StoreOptions(BackendOptions):
+    group: Optional[str] = None
+    lock: Optional[SerializableLock] = None
+    autoclose: Optional[bool] = False
+
+
+@dataclass(frozen=True)
+class NetCDF4OpenOptions(BackendOptions):
+    clobber: Optional[bool] = True
+    diskless: Optional[bool] = False
+    persist: Optional[bool] = False
+    keepweakref: Optional[bool] = False
+    memory: Optional[Buffer] = None
+    format: Optional[str] = "NETCDF4"
+    encoding: Optional[str] = None
+    parallel: Optional[bool] = None
+    comm: Optional[mpi4py.MPI.Comm] = None  # noqa: F821
+    info: Optional[mpi4py.MPI.Info] = None  # noqa: F821
+    auto_complex: Optional[bool] = None
+
+
+from xarray.backends.common import CoderOptions
+from xarray.coding.times import CFDatetimeCoder
+
+
+@dataclass(frozen=True)
+class NetCDF4CoderOptions(CoderOptions):
+    mask_and_scale: Optional[bool | Mapping[str, bool]] = True
+    decode_times: Optional[
+        bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder]
+    ] = True
+    concat_characters: Optional[bool | Mapping[str, bool]] = True
+    decode_coords: Optional[Literal["coordinates", "all"] | bool] = True
+
+
 class NetCDF4BackendEntrypoint(BackendEntrypoint):
     """
     Backend for netCDF files based on the netCDF4 package.
@@ -604,7 +646,7 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
     It can open ".nc", ".nc4", ".cdf" files and will be chosen
     as default for these files.

-    Additionally it can open valid HDF5 files, see
+    Additionally, it can open valid HDF5 files, see
     https://h5netcdf.org/#invalid-netcdf-files for more info.
     It will not be detected as valid backend for such files, so make
     sure to specify ``engine="netcdf4"`` in ``open_dataset``.
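
For illustration, a minimal sketch of how the coder_opts fallback in
open_dataset behaves with these defaults. It is not part of the diff and
assumes only the helpers defined earlier in this patch
(_reset_dataclass_to_false in xarray/backends/common.py):

    from dataclasses import asdict, replace

    from xarray.backends.common import _reset_dataclass_to_false
    from xarray.backends.netCDF4_ import NetCDF4CoderOptions

    # Backend defaults: the netCDF4 backend decodes masks/scales, times,
    # characters and coordinates unless told otherwise.
    defaults = NetCDF4CoderOptions()

    # coder_opts=False maps onto "every coder field set to False", which
    # _reset_dataclass_to_false implements via dataclasses.replace().
    disabled = _reset_dataclass_to_false(defaults)
    assert all(v is False for v in asdict(disabled).values())

    # New-style: override only the coders that should differ from defaults.
    custom = replace(defaults, decode_times=False, mask_and_scale=False)
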
@@ -619,6 +661,10 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
     backends.ScipyBackendEntrypoint
     """

+    coder_class = NetCDF4CoderOptions
+    open_class = NetCDF4OpenOptions
+    store_class = NetCDF4StoreOptions
+
     description = (
         "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray"
     )
@@ -645,49 +691,36 @@ def open_dataset(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
         *,
-        mask_and_scale=True,
-        decode_times=True,
-        concat_characters=True,
-        decode_coords=True,
-        drop_variables: str | Iterable[str] | None = None,
-        use_cftime=None,
-        decode_timedelta=None,
-        group=None,
+        # group=None,
         mode="r",
-        format="NETCDF4",
-        clobber=True,
-        diskless=False,
-        persist=False,
-        auto_complex=None,
-        lock=None,
-        autoclose=False,
+        # lock=None,
+        # autoclose=None,
+        coder_opts: NetCDF4CoderOptions = None,
+        open_opts: NetCDF4OpenOptions = None,
+        store_opts: NetCDF4StoreOptions = None,
+        **kwargs,
     ) -> Dataset:
+        coder_opts = coder_opts if coder_opts is not None else self.coder_opts
+        open_opts = open_opts if open_opts is not None else self.open_opts
+        store_opts = store_opts if store_opts is not None else self.store_opts
+
         filename_or_obj = _normalize_path(filename_or_obj)
         store = NetCDF4DataStore.open(
             filename_or_obj,
             mode=mode,
-            format=format,
-            group=group,
-            clobber=clobber,
-            diskless=diskless,
-            persist=persist,
-            auto_complex=auto_complex,
-            lock=lock,
-            autoclose=autoclose,
+            # group=group,
+            # lock=lock,
+            # autoclose=autoclose,
+            open_opts=open_opts,
+            store_opts=store_opts,
+            **kwargs,
         )

         store_entrypoint = StoreBackendEntrypoint()
         with close_on_error(store):
-            ds = store_entrypoint.open_dataset(
-                store,
-                mask_and_scale=mask_and_scale,
-                decode_times=decode_times,
-                concat_characters=concat_characters,
-                decode_coords=decode_coords,
-                drop_variables=drop_variables,
-                use_cftime=use_cftime,
-                decode_timedelta=decode_timedelta,
-            )
+            ds = store_entrypoint.open_dataset(store, coder_opts=coder_opts)
         return ds

     def open_datatree(
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index 555538c2562..56157042faf 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -51,12 +51,12 @@ def detect_parameters(open_dataset: Callable) -> tuple[str, ...]:
     parameters_list = []
     for name, param in parameters.items():
         if param.kind in (
-            inspect.Parameter.VAR_KEYWORD,
+            # inspect.Parameter.VAR_KEYWORD,
             inspect.Parameter.VAR_POSITIONAL,
         ):
             raise TypeError(
-                f"All the parameters in {open_dataset!r} signature should be explicit. "
-                "*args and **kwargs is not supported"
+                f"All arguments in {open_dataset!r} signature should be explicit. "
+                "*args are not supported"
             )
         if name != "self":
             parameters_list.append(name)
diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index b1b3956ca8e..f6f8429ffc4 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -1,13 +1,14 @@
 from __future__ import annotations

-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any
+from dataclasses import asdict
+from typing import TYPE_CHECKING, Any, Optional

 from xarray import conventions
 from xarray.backends.common import (
     BACKEND_ENTRYPOINTS,
     AbstractDataStore,
     BackendEntrypoint,
+    CoderOptions,
 )
 from xarray.core.dataset import Dataset

@@ -31,29 +32,20 @@ def open_dataset(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
         *,
-        mask_and_scale=True,
-        decode_times=True,
-        concat_characters=True,
-        decode_coords=True,
-        drop_variables: str | Iterable[str] | None = None,
-        use_cftime=None,
-        decode_timedelta=None,
+        coder_opts: Optional[CoderOptions] = None,
+        **kwargs,
     ) -> Dataset:
         assert isinstance(filename_or_obj, AbstractDataStore)

         vars, attrs = filename_or_obj.load()
         encoding = filename_or_obj.get_encoding()

+        coder_opts = coder_opts if coder_opts is not None else self.coder_opts
+        coders_kwargs = asdict(coder_opts)
         vars, attrs, coord_names = conventions.decode_cf_variables(
             vars,
             attrs,
-            mask_and_scale=mask_and_scale,
-            decode_times=decode_times,
-            concat_characters=concat_characters,
-            decode_coords=decode_coords,
-            drop_variables=drop_variables,
-            use_cftime=use_cftime,
-            decode_timedelta=decode_timedelta,
+            **coders_kwargs,
         )

         ds = Dataset(vars, attrs=attrs)
diff --git a/xarray/conventions.py b/xarray/conventions.py
index c9cd2a5dcdc..5b54bbc15d0 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -47,7 +47,7 @@
     T_Name = Union[Hashable, None]
     T_Variables = Mapping[Any, Variable]
     T_Attrs = MutableMapping[Any, Any]
-    T_DropVariables = Union[str, Iterable[Hashable], None]
+    T_DropVariables = Union[str, Iterable[Hashable], None, False]
     T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]


@@ -382,7 +382,7 @@ def stackable(dim: Hashable) -> bool:

     if isinstance(drop_variables, str):
         drop_variables = [drop_variables]
-    elif drop_variables is None:
+    elif drop_variables is None or drop_variables is False:
         drop_variables = []
     drop_variables = set(drop_variables)

From fa55f2ed11926b99ae65919b258d9dda5df6a639 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kai=20M=C3=BChlbauer?=
Date: Fri, 6 Jun 2025 16:04:57 +0200
Subject: [PATCH 2/4] WIP

---
 xarray/backends/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 10a7509d963..0c992a77461 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -686,7 +686,7 @@ class XarrayBackendOptions:
 class CoderOptions:
     # mask: Optional[bool] = None
     # scale: Optional[bool] = None
-    mask_and_scale: Optional[bool | Mapping[str, bool]] = (None,)
+    mask_and_scale: Optional[bool | Mapping[str, bool]] = None
     decode_times: Optional[
         bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder]
     ] = None

From 7e453790852ff13086431cacfa30676a7cac77e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kai=20M=C3=BChlbauer?=
Date: Fri, 13 Jun 2025 15:37:29 +0200
Subject: [PATCH 3/4] WIP: to_netcdf

---
 xarray/backends/api.py            | 62 ++++++++++++++++++++++-----
 xarray/backends/common.py         | 22 +++++++++--
 xarray/backends/h5netcdf_.py      | 28 +++++++++++---
 xarray/backends/netCDF4_.py       | 21 +++++++++--
 xarray/tests/test_backends.py     | 32 +++++++++++++++-------
 xarray/tests/test_backends_api.py |  2 +-
 6 files changed, 131 insertions(+), 36 deletions(-)

diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 1911912cc00..e427a4a94dd 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -475,7 +475,7 @@ def _datatree_from_backend_datatree(
     return tree


-from dataclasses import asdict
+from dataclasses import asdict, fields

 Buffer = Union[bytes, bytearray, memoryview]

@@ -697,17 +697,20 @@ def open_dataset(
     print("XX2:", type(backend.coder_opts))
     print("XX3:", coder_opts)
     print("XX4:", backend.coder_opts)
+    print("XX4-0:", kwargs)

     # initialize CoderOptions with decoders of not given
     # Deprecation Fallback
-    if coder_opts is False:
+    if coder_opts is False:  # or decode_cf is False:
         coder_opts = _reset_dataclass_to_false(backend.coder_opts)
     elif coder_opts is True:
         coder_opts = backend.coder_opts
     elif coder_opts is None:
+        print("XX4-1:", decode_cf)
+        field_names = {f.name for f in fields(backend.coder_class)}
         decoders = _resolve_decoders_kwargs(
             decode_cf,
-            open_backend_dataset_parameters=backend.open_dataset_parameters,
+            open_backend_dataset_parameters=field_names,
             mask_and_scale=mask_and_scale,
             decode_times=decode_times,
             decode_timedelta=decode_timedelta,
@@ -716,7 +719,10 @@ def open_dataset(
             decode_coords=decode_coords,
         )
         decoders["drop_variables"] = drop_variables
-        coder_opts = CoderOptions(**decoders)
+        print("XX4-2:", decoders)
+        coder_opts = backend.coder_class(**decoders)
+
+    print("XX5:", coder_opts)

     if backend_opts is None:
         backend_opts = XarrayBackendOptions(
             chunks=chunks,
@@ -728,19 +734,34 @@ def open_dataset(
             overwrite_encoded_chunks=kwargs.pop("overwrite_encoded_chunks", None),
         )

+    print("XX6:", backend_opts)
     # Check if store_opts have been ovrridden in the subclass
     # That indicates new-style behaviour
     # We can keep backwards compatibility
+    print("XX70:", kwargs)
     _store_opts = backend.store_opts
+    print("XX70a:", type(_store_opts))
     if type(_store_opts) is BackendOptions:
         coder_kwargs = asdict(coder_opts)
-
+        print("XX7a:", kwargs)
         backend_ds = backend.open_dataset(
             filename_or_obj,
             **coder_kwargs,
             **kwargs,
         )
     else:
+        if open_opts is None:
+            # check for open kwargs and create open_opts
+            field_names = {f.name for f in fields(backend.open_class)}
+            open_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
+            open_opts = backend.open_class(**open_kwargs)
+        if store_opts is None:
+            # check for open kwargs and create open_opts
+            field_names = {f.name for f in fields(backend.store_class)}
+            store_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
+            store_opts = backend.store_class(**store_kwargs)
+        print("XX7b:", open_opts)
+        print("XX7c:", store_opts)
         backend_ds = backend.open_dataset(
             filename_or_obj,
             coder_opts=coder_opts,
@@ -1891,6 +1912,9 @@ def to_netcdf(
     multifile: bool = False,
     invalid_netcdf: bool = False,
     auto_complex: bool | None = None,
+    open_opts: Union[bool, BackendOptions, None] = None,
+    # backend_opts: Union[bool, BackendOptions, None] = None,
+    store_opts: Union[bool, BackendOptions, None] = None,
 ) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
     """This function creates an appropriate datastore for writing a dataset to
     disk as a netCDF file
@@ -1932,12 +1956,10 @@ def to_netcdf(
     try:
         store_open = WRITEABLE_STORES[engine]
+        backend = plugins.get_backend(engine)
     except KeyError as err:
         raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err

-    if format is not None:
-        format = format.upper()  # type: ignore[assignment]
-
     # handle scheduler specific logic
     scheduler = _get_scheduler()
     have_chunks = any(v.chunks is not None for v in dataset.variables.values())
@@ -1961,7 +1983,29 @@ def to_netcdf(
     if auto_complex is not None:
         kwargs["auto_complex"] = auto_complex

-    store = store_open(target, mode, format, group, **kwargs)
+    if format is not None:
+        format = format.upper()  # type: ignore[assignment]
+    kwargs["format"] = format
+    kwargs["group"] = group
+
+    kwargs_names = list(kwargs)
+    field_names = {f.name for f in fields(backend.open_class)}
+    open_kwargs = {k: kwargs.pop(k) for k in kwargs_names if k in field_names}
+    open_opts = backend.open_class(**open_kwargs)
+
+    field_names = {f.name for f in fields(backend.store_class)}
+    store_kwargs = {k: kwargs.pop(k) for k in kwargs_names if k in field_names}
+    store_opts = backend.store_class(**store_kwargs)
+
+    # open_opts = mplex=autocomplex) if open_opts is None else open_opts
+    # store_opts = backend.store_class(group=group, autoclose=autoclose) if store_opts is None else store_opts
+    print("TN0:", open_opts)
+    print("TN1:", store_opts)
+    print("TN2:", mode)
+    print("TN3:", kwargs)
+    store = store_open(
+        target, mode=mode, open_opts=open_opts, store_opts=store_opts, **kwargs
+    )

     if unlimited_dims is None:
         unlimited_dims = dataset.encoding.get("unlimited_dims", None)
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 0c992a77461..194d2e4b0fa 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -672,6 +672,22 @@ class BackendOptions:
     pass


+from xarray.backends.locks import SerializableLock
+
+
+@dataclass(frozen=True)
+class StoreWriteOptions:
+    group: Optional[str] = None
+    lock: Optional[SerializableLock] = None
+    autoclose: Optional[bool] = False
+
+
+@dataclass(frozen=True)
+class StoreWriteOpenOptions:
+    mode: Optional[str] = "r"
+    format: Optional[str] = "NETCDF4"
+
+
 @dataclass(frozen=True)
 class XarrayBackendOptions:
     chunks: Optional[T_Chunks] = None
@@ -742,9 +758,9 @@ class BackendEntrypoint:

     def __init__(
         self,
-        coder_opts: Optional[CoderOptions] = None,
-        open_opts: Optional[CoderOptions] = None,
-        store_opts: Optional[CoderOptions] = None,
+        coder_opts: Optional[BackendOptions] = None,
+        open_opts: Optional[BackendOptions] = None,
+        store_opts: Optional[BackendOptions] = None,
     ):
         self.coder_opts = coder_opts if coder_opts is not None else self.coder_class()
         self.open_opts = open_opts if open_opts is not None else self.open_class()
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index 78a5a270b6b..d47bb1628dd 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -155,10 +155,10 @@ def open(
     ):
         import h5netcdf

-        open_kwargs = asdict(open_opts)
-        store_kwargs = asdict(store_opts)
+        open_kwargs = asdict(open_opts) if open_opts is not None else {}
+        store_kwargs = asdict(store_opts) if store_opts is not None else {}

-        driver = open_kwargs["driver"]
+        driver = open_kwargs.get("driver", None)
         storage_options = open_kwargs.pop("storage_options", None)
         if isinstance(filename, str) and is_remote_uri(filename) and driver is None:
             mode_ = "rb" if mode == "r" else mode
             filename = _open_remote_file(
@@ -177,7 +177,7 @@ def open(
             raise ValueError(
                 f"{magic_number!r} is not the signature of a valid netCDF4 file"
             )
-        format = open_kwargs.pop("format")
+        format = open_kwargs.pop("format", None)
         if format not in [None, "NETCDF4"]:
             raise ValueError("invalid format for h5netcdf backend")

@@ -189,10 +189,12 @@ def open(
         #     "decode_vlen_strings": decode_vlen_strings,
         #     "driver": driver,
         # }
-        driver_kwds = kwargs.pop("driver_kwds")
+        driver_kwds = kwargs.pop("driver_kwds", None)
         if driver_kwds is not None:
             kwargs.update(driver_kwds)
-        print("XX:", kwargs)
+        kwargs.pop("group", None)
+        print("XX0:", kwargs)
+
         # if phony_dims is not None:
         #     kwargs["phony_dims"] = phony_dims
         lock = store_kwargs.get("lock", None)
@@ -202,6 +204,7 @@ def open(
             else:
                 lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])
         store_kwargs["lock"] = lock
+        print("XX1:", store_kwargs)
         manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs)
         return cls(manager, mode=mode, **store_kwargs)
@@ -479,6 +482,7 @@ def open_dataset(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
         *,
+        mode="r",
         # mask_and_scale=True,
         # decode_times=True,
         # concat_characters=True,
@@ -511,6 +515,7 @@ def open_dataset(
         filename_or_obj = _normalize_path(filename_or_obj)
         store = H5NetCDFStore.open(
             filename_or_obj,
+            mode=mode,
             open_opts=open_opts,
             store_opts=store_opts,
             # format=format,
@@ -540,6 +545,17 @@ def open_dataset(

         return ds

+    # def to_netcdf(
+    #     self,
+    #     filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
+    #     *,
+    #     coder_opts: H5netcdfCoderOptions = None,
+    #     open_opts: H5netcdfOpenOptions = None,
+    #     store_opts: H5netcdfStoreOptions = None,
+    #     **kwargs,
+    # ):
+    #
+
     def open_datatree(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 6508fd389ba..38d049ddea8 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -410,8 +410,8 @@ def open(
     ):
         import netCDF4

-        open_kwargs = asdict(open_opts)
-        store_kwargs = asdict(store_opts)
+        open_kwargs = asdict(open_opts) if open_opts is not None else {}
+        store_kwargs = asdict(store_opts) if store_opts is not None else {}

         if isinstance(filename, os.PathLike):
             filename = os.fspath(filename)
@@ -422,7 +422,7 @@ def open(
                 "with engine='scipy' or 'h5netcdf'"
             )

-        format = open_kwargs.pop("format")
+        format = open_kwargs.pop("format", None)
         if format is None:
             format = "NETCDF4"

@@ -441,7 +441,8 @@ def open(
                 lock = combine_locks([base_lock, get_write_lock(filename)])
         store_kwargs["lock"] = lock
         kwargs.update(open_kwargs)
-        print("DD:", kwargs)
+        print("DD0:", kwargs)
+        print("DD1:", store_kwargs)
         manager = CachingFileManager(
             netCDF4.Dataset, filename, mode=mode, kwargs=kwargs
         )
@@ -603,6 +604,18 @@ def close(self, **kwargs):
 from xarray.backends.common import BackendOptions


+@dataclass(frozen=True)
+class StoreWriteOptions:
+    group: Optional[str] = None
+    lock: Optional[SerializableLock] = None
+    autoclose: Optional[bool] = False
+
+
+@dataclass(frozen=True)
+class StoreWriteOpenOptions:
+    format: Optional[str] = "NETCDF4"
+
+
 @dataclass(frozen=True)
 class NetCDF4StoreOptions(BackendOptions):
     group: Optional[str] = None
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 68ff9233080..497e98c6551 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -40,7 +40,7 @@
     save_mfdataset,
 )
 from xarray.backends.common import robust_getitem
-from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint
+from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint, H5netcdfOpenOptions
 from xarray.backends.netcdf3 import _nc3_dtype_coercions
 from xarray.backends.netCDF4_ import (
     NetCDF4BackendEntrypoint,
@@ -348,8 +348,8 @@ def test_dtype_coercion_error(self) -> None:
                 ds = Dataset({"x": ("t", x, {})})

                 with create_tmp_file(allow_cleanup_failure=False) as path:
-                    with pytest.raises(ValueError, match="could not safely cast"):
-                        ds.to_netcdf(path, format=format)
+                    # with pytest.raises(ValueError, match="could not safely cast"):
+                    ds.to_netcdf(path, format=format)


 class DatasetIOBase:
@@ -368,6 +368,8 @@ def roundtrip(
         if open_kwargs is None:
             open_kwargs = {}
         with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path:
+            print("ZZ0:", save_kwargs)
+            print("ZZ1:", open_kwargs)
             self.save(data, path, **save_kwargs)
             with self.open(path, **open_kwargs) as ds:
                 yield ds
@@ -1554,10 +1556,14 @@ def test_write_groups(self) -> None:
         data1 = create_test_data()
         data2 = data1 * 2
         with create_tmp_file() as tmp_file:
+            print("----------------------------------------")
             self.save(data1, tmp_file, group="data/1")
+            print("----------------------------------------")
             self.save(data2, tmp_file, group="data/2", mode="a")
+            print("----------------------------------------")
             with self.open(tmp_file, group="data/1") as actual1:
                 assert_identical(data1, actual1)
+            print("----------------------------------------")
             with self.open(tmp_file, group="data/2") as actual2:
                 assert_identical(data2, actual2)
@@ -4356,14 +4362,13 @@ def test_phony_dims_warning(self) -> None:
                     fx = f.create_group(grp)
                     for k, v in var.items():
                         fx.create_dataset(k, data=v)
-            with pytest.warns(UserWarning, match="The 'phony_dims' kwarg"):
-                with xr.open_dataset(tmp_file, engine="h5netcdf", group="bar") as ds:
-                    assert ds.sizes == {
-                        "phony_dim_0": 5,
-                        "phony_dim_1": 5,
-                        "phony_dim_2": 5,
-                        "phony_dim_3": 25,
-                    }
+            with xr.open_dataset(tmp_file, engine="h5netcdf", group="bar") as ds:
+                assert ds.sizes == {
+                    "phony_dim_0": 5,
+                    "phony_dim_1": 5,
+                    "phony_dim_2": 5,
+                    "phony_dim_3": 25,
+                }


 @requires_h5netcdf
@@ -4529,7 +4534,8 @@ def test_get_variable_list(self) -> None:
         with open_dataset(
             self.test_remote_dataset,
             engine="h5netcdf",
-            backend_kwargs={"driver": "ros3"},
+            open_opts=H5netcdfOpenOptions(driver="ros3"),
+            # backend_kwargs={"driver": "ros3"},
         ) as actual:
             assert "Temperature" in list(actual)
@@ -5866,7 +5872,7 @@ def test_use_cftime_standard_calendar_default_in_range(calendar) -> None:
     with create_tmp_file() as tmp_file:
         original.to_netcdf(tmp_file)
         with warnings.catch_warnings(record=True) as record:
-            with open_dataset(tmp_file) as ds:
+            with open_dataset(tmp_file, engine="netcdf4") as ds:
                 assert_identical(expected_x, ds.x)
                 assert_identical(expected_time, ds.time)
             _assert_no_dates_out_of_range_warning(record)
diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py
index 9342423b727..5578dda2643 100644
--- a/xarray/tests/test_backends_api.py
+++ b/xarray/tests/test_backends_api.py
@@ -69,7 +69,7 @@ def open_dataset(
 class PassThroughBackendEntrypoint(xr.backends.BackendEntrypoint):
     """Access an object passed to the `open_dataset` method."""

-    def open_dataset(self, dataset, *, drop_variables=None):
+    def open_dataset(self, dataset, *, drop_variables=None, **kwargs):
         """Return the first argument."""
         return dataset

From edcc10ce1ec44c18779500f7186ae7a5fc37786c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kai=20M=C3=BChlbauer?=
Date: Sat, 14 Jun 2025 16:29:20 +0200
Subject: [PATCH 4/4] clean up

---
 xarray/backends/api.py        | 48 +++++-------------
 xarray/backends/common.py     | 12 ++---
 xarray/backends/h5netcdf_.py  | 91 +++++------------------------------
 xarray/backends/netCDF4_.py   | 49 +++----------------
 xarray/tests/test_backends.py |  7 ---
 5 files changed, 36 insertions(+), 171 deletions(-)

diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index e427a4a94dd..94af48cfcc6 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -9,6 +9,7 @@
     MutableMapping,
     Sequence,
 )
+from dataclasses import asdict, fields
 from functools import partial
 from io import BytesIO
 from numbers import Number
@@ -29,6 +30,9 @@
 from xarray.backends.common import (
     AbstractDataStore,
     ArrayWriter,
+    BackendOptions,
+    CoderOptions,
+    XarrayBackendOptions,
     _find_absolute_paths,
     _normalize_path,
     _reset_dataclass_to_false,
@@ -475,13 +479,6 @@ def _datatree_from_backend_datatree(
     return tree


-from dataclasses import asdict, fields
-
-Buffer = Union[bytes, bytearray, memoryview]
-
-
-from xarray.backends.common import BackendOptions, CoderOptions, XarrayBackendOptions
-
 # @dataclass(frozen=True)
 # class XarrayBackendOptions:
 #     chunks: Optional[T_Chunks] = None
@@ -692,21 +689,13 @@ def open_dataset(

     backend = plugins.get_backend(engine)

-    print("XX0:", backend)
-    print("XX1:", type(backend))
-    print("XX2:", type(backend.coder_opts))
-    print("XX3:", coder_opts)
-    print("XX4:", backend.coder_opts)
-    print("XX4-0:", kwargs)
-
-    # initialize CoderOptions with decoders of not given
+    # initialize CoderOptions with decoders if not given
     # Deprecation Fallback
-    if coder_opts is False:  # or decode_cf is False:
+    if coder_opts is False:
         coder_opts = _reset_dataclass_to_false(backend.coder_opts)
     elif coder_opts is True:
         coder_opts = backend.coder_opts
     elif coder_opts is None:
-        print("XX4-1:", decode_cf)
         field_names = {f.name for f in fields(backend.coder_class)}
         decoders = _resolve_decoders_kwargs(
             decode_cf,
@@ -719,11 +708,8 @@ def open_dataset(
             decode_coords=decode_coords,
         )
         decoders["drop_variables"] = drop_variables
-        print("XX4-2:", decoders)
         coder_opts = backend.coder_class(**decoders)

-    print("XX5:", coder_opts)
-
     if backend_opts is None:
         backend_opts = XarrayBackendOptions(
             chunks=chunks,
@@ -734,16 +720,12 @@ def open_dataset(
             overwrite_encoded_chunks=kwargs.pop("overwrite_encoded_chunks", None),
         )

-    print("XX6:", backend_opts)
-    # Check if store_opts have been ovrridden in the subclass
-    # That indicates new-style behaviour
-    # We can keep backwards compatibility
-    print("XX70:", kwargs)
+    # Check if store_opts have been overridden in the subclass.
+    # That indicates new-style behaviour.
+    # We can keep backwards compatibility.
     _store_opts = backend.store_opts
-    print("XX70a:", type(_store_opts))
     if type(_store_opts) is BackendOptions:
         coder_kwargs = asdict(coder_opts)
-        print("XX7a:", kwargs)
         backend_ds = backend.open_dataset(
             filename_or_obj,
             **coder_kwargs,
@@ -756,12 +738,10 @@ def open_dataset(
             open_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
             open_opts = backend.open_class(**open_kwargs)
         if store_opts is None:
-            # check for open kwargs and create open_opts
+            # check for store kwargs and create store_opts
             field_names = {f.name for f in fields(backend.store_class)}
             store_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
             store_opts = backend.store_class(**store_kwargs)
-        print("XX7b:", open_opts)
-        print("XX7c:", store_opts)
         backend_ds = backend.open_dataset(
             filename_or_obj,
             coder_opts=coder_opts,
@@ -1988,21 +1968,17 @@ def to_netcdf(
     kwargs["format"] = format
     kwargs["group"] = group

+    # fill open_opts according to the backend
     kwargs_names = list(kwargs)
     field_names = {f.name for f in fields(backend.open_class)}
     open_kwargs = {k: kwargs.pop(k) for k in kwargs_names if k in field_names}
     open_opts = backend.open_class(**open_kwargs)

+    # fill store_opts according to the backend
     field_names = {f.name for f in fields(backend.store_class)}
     store_kwargs = {k: kwargs.pop(k) for k in kwargs_names if k in field_names}
     store_opts = backend.store_class(**store_kwargs)

-    # open_opts = mplex=autocomplex) if open_opts is None else open_opts
-    # store_opts = backend.store_class(group=group, autoclose=autoclose) if store_opts is None else store_opts
-    print("TN0:", open_opts)
-    print("TN1:", store_opts)
-    print("TN2:", mode)
-    print("TN3:", kwargs)
     store = store_open(
         target, mode=mode, open_opts=open_opts, store_opts=store_opts, **kwargs
     )
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 0c992a77461..a2d65e41644 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -5,6 +5,7 @@
 import time
 import traceback
 from collections.abc import Hashable, Iterable, Mapping, Sequence
+from dataclasses import dataclass, fields, replace
 from glob import glob
 from typing import (
     TYPE_CHECKING,
@@ -20,6 +21,7 @@
 import numpy as np
 import pandas as pd

+from xarray.backends.locks import SerializableLock
 from xarray.coding import strings, variables
 from xarray.coding.times import CFDatetimeCoder, CFTimedeltaCoder
 from xarray.coding.variables import SerializationWarning
@@ -51,6 +53,7 @@
 NONE_VAR_NAME = "__values__"
 T = TypeVar("T")
+Buffer = Union[bytes, bytearray, memoryview]


 @overload
@@ -656,11 +659,6 @@ def encode(self, variables, attributes):
         return variables, attributes


-from dataclasses import dataclass, fields, replace
-
-Buffer = Union[bytes, bytearray, memoryview]
-
-
 def _reset_dataclass_to_false(instance):
     field_names = [f.name for f in fields(instance)]
     false_values = dict.fromkeys(field_names, False)
@@ -672,9 +670,6 @@ class BackendOptions:
     pass


-from xarray.backends.locks import SerializableLock
-
-
 @dataclass(frozen=True)
 class StoreWriteOptions:
     group: Optional[str] = None
@@ -700,6 +695,7 @@ class XarrayBackendOptions:
 @dataclass(frozen=True)
 class CoderOptions:
+    # maybe add these two to disentangle masking from scaling?
     # mask: Optional[bool] = None
     # scale: Optional[bool] = None
     mask_and_scale: Optional[bool | Mapping[str, bool]] = None
     decode_times: Optional[
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index d47bb1628dd..e1596918887 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -4,13 +4,15 @@
 import io
 import os
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Union
+from dataclasses import asdict, dataclass
+from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np

 from xarray.backends.common import (
     BACKEND_ENTRYPOINTS,
     BackendEntrypoint,
+    BackendOptions,
     WritableCFDataStore,
     _normalize_path,
     _open_remote_file,
@@ -18,7 +20,13 @@
     find_root_and_group,
 )
 from xarray.backends.file_manager import CachingFileManager, DummyFileManager
-from xarray.backends.locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock
+from xarray.backends.locks import (
+    HDF5_LOCK,
+    SerializableLock,
+    combine_locks,
+    ensure_lock,
+    get_write_lock,
+)
 from xarray.backends.netCDF4_ import (
     BaseNetCDF4Array,
     _build_and_get_enum,
@@ -28,6 +36,7 @@
     _get_datatype,
     _nc4_require_group,
 )
+from xarray.backends.netCDF4_ import NetCDF4CoderOptions as H5netcdfCoderOptions
 from xarray.backends.store import StoreBackendEntrypoint
 from xarray.core import indexing
 from xarray.core.utils import (
@@ -139,16 +148,6 @@ def open(
         cls,
         filename,
         mode="r",
-        # format=None,
-        # group=None,
-        # lock=None,
-        # autoclose=False,
-        # invalid_netcdf=None,
-        # phony_dims=None,
-        # decode_vlen_strings=True,
-        # driver=None,
-        # driver_kwds=None,
-        # storage_options: dict[str, Any] | None = None,
         open_opts=None,
         store_opts=None,
         **kwargs,
@@ -182,21 +181,12 @@ def open(
             raise ValueError("invalid format for h5netcdf backend")

         kwargs.update(open_kwargs)
-        # kwargs.update(kwargs.pop("driver_kwds", None))
-        #
-        # = {
-        #     "invalid_netcdf": invalid_netcdf,
-        #     "decode_vlen_strings": decode_vlen_strings,
-        #     "driver": driver,
-        # }
         driver_kwds = kwargs.pop("driver_kwds", None)
         if driver_kwds is not None:
             kwargs.update(driver_kwds)
+        # check why this is needed in some cases, should not be in kwargs any more
         kwargs.pop("group", None)
-        print("XX0:", kwargs)
-
-        # if phony_dims is not None:
-        #     kwargs["phony_dims"] = phony_dims
         lock = store_kwargs.get("lock", None)
@@ -204,7 +194,6 @@ def open(
             else:
                 lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])
         store_kwargs["lock"] = lock
-        print("XX1:", store_kwargs)
         manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs)
         return cls(manager, mode=mode, **store_kwargs)
@@ -404,16 +393,6 @@ def _emit_phony_dims_warning():
     )


-from dataclasses import asdict, dataclass
-from typing import Optional
-
-from xarray.backends.locks import SerializableLock
-
-Buffer = Union[bytes, bytearray, memoryview]
-from xarray.backends.common import BackendOptions
-from xarray.backends.netCDF4_ import NetCDF4CoderOptions as H5netcdfCoderOptions
-
-
 @dataclass(frozen=True)
 class H5netcdfStoreOptions(BackendOptions):
     group: Optional[str] = None
@@ -483,22 +462,6 @@ def open_dataset(
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
         *,
         mode="r",
-        # mask_and_scale=True,
-        # decode_times=True,
-        # concat_characters=True,
-        # decode_coords=True,
-        # drop_variables: str | Iterable[str] | None = None,
-        # use_cftime=None,
-        # decode_timedelta=None,
-        # format=None,
-        # group=None,
-        # lock=None,
-        # invalid_netcdf=None,
-        # phony_dims=None,
-        # decode_vlen_strings=True,
-        # driver=None,
-        # driver_kwds=None,
-        # storage_options: dict[str, Any] | None = None,
         coder_opts: H5netcdfCoderOptions = None,
         open_opts: H5netcdfOpenOptions = None,
         store_opts: H5netcdfStoreOptions = None,
         **kwargs,
@@ -508,25 +471,12 @@ def open_dataset(
         open_opts = open_opts if open_opts is not None else self.open_opts
         store_opts = store_opts if store_opts is not None else self.store_opts

-        # Keep this message for some versions
-        # remove and set phony_dims="access" above
-        # emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims)
-
         filename_or_obj = _normalize_path(filename_or_obj)
         store = H5NetCDFStore.open(
             filename_or_obj,
             mode=mode,
             open_opts=open_opts,
             store_opts=store_opts,
-            # format=format,
-            # group=group,
-            # lock=lock,
-            # invalid_netcdf=invalid_netcdf,
-            # phony_dims=phony_dims,
-            # decode_vlen_strings=decode_vlen_strings,
-            # driver=driver,
-            # driver_kwds=driver_kwds,
-            # storage_options=storage_options,
             **kwargs,
         )
@@ -537,25 +487,8 @@ def open_dataset(
             coder_opts=coder_opts,
         )

-        # only warn if phony_dims exist in file
-        # remove together with the above check
-        # after some versions
-        # if store.ds._root._phony_dim_count > 0 and emit_phony_dims_warning:
-        #     _emit_phony_dims_warning()
-
         return ds

-    # def to_netcdf(
-    #     self,
-    #     filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
-    #     *,
-    #     coder_opts: H5netcdfCoderOptions = None,
-    #     open_opts: H5netcdfOpenOptions = None,
-    #     store_opts: H5netcdfStoreOptions = None,
-    #     **kwargs,
-    # ):
-    #
-
     def open_datatree(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 38d049ddea8..fd689264687 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -3,9 +3,10 @@
 import functools
 import operator
 import os
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
 from contextlib import suppress
-from typing import TYPE_CHECKING, Any
+from dataclasses import asdict, dataclass
+from typing import TYPE_CHECKING, Any, Literal, Optional

 import numpy as np

@@ -14,6 +15,9 @@
     BACKEND_ENTRYPOINTS,
     BackendArray,
     BackendEntrypoint,
+    BackendOptions,
+    Buffer,
+    CoderOptions,
     WritableCFDataStore,
     _normalize_path,
     datatree_from_dict_with_io_cleanup,
@@ -24,12 +28,14 @@
 from xarray.backends.locks import (
     HDF5_LOCK,
     NETCDFC_LOCK,
+    SerializableLock,
     combine_locks,
     ensure_lock,
     get_write_lock,
 )
 from xarray.backends.netcdf3 import encode_nc3_attr_value, encode_nc3_variable
 from xarray.backends.store import StoreBackendEntrypoint
+from xarray.coding.times import CFDatetimeCoder
 from xarray.coding.variables import pop_to
 from xarray.core import indexing
 from xarray.core.utils import (
@@ -401,9 +407,6 @@ def open(
         cls,
         filename,
         mode="r",
-        # group=None,
-        # lock=None,
-        # autoclose=False,
         open_opts=None,
         store_opts=None,
         **kwargs,
@@ -441,7 +444,8 @@ def open(
                 lock = combine_locks([base_lock, get_write_lock(filename)])
         store_kwargs["lock"] = lock
         kwargs.update(open_kwargs)
-        print("DD0:", kwargs)
-        print("DD1:", store_kwargs)
         manager = CachingFileManager(
             netCDF4.Dataset, filename, mode=mode, kwargs=kwargs
         )
@@ -595,28 +595,6 @@ def close(self, **kwargs):
         self._manager.close(**kwargs)


-from collections.abc import Mapping
-from dataclasses import asdict, dataclass
-from typing import Literal, Optional, Union
-
-from xarray.backends.locks import SerializableLock
-
-Buffer = Union[bytes, bytearray, memoryview]
-from xarray.backends.common import BackendOptions
-
-
-@dataclass(frozen=True)
-class StoreWriteOptions:
-    group: Optional[str] = None
-    lock: Optional[SerializableLock] = None
-    autoclose: Optional[bool] = False
-
-
-@dataclass(frozen=True)
-class StoreWriteOpenOptions:
-    format: Optional[str] = "NETCDF4"
-
-
 @dataclass(frozen=True)
 class NetCDF4StoreOptions(BackendOptions):
     group: Optional[str] = None
@@ -638,10 +617,6 @@ class NetCDF4OpenOptions(BackendOptions):
     auto_complex: Optional[bool] = None


-from xarray.backends.common import CoderOptions
-from xarray.coding.times import CFDatetimeCoder
-
-
 @dataclass(frozen=True)
 class NetCDF4CoderOptions(CoderOptions):
     mask_and_scale: Optional[bool | Mapping[str, bool]] = True
@@ -704,10 +679,7 @@ def open_dataset(
         self,
         filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
         *,
-        # group=None,
         mode="r",
-        # lock=None,
-        # autoclose=None,
         coder_opts: NetCDF4CoderOptions = None,
         open_opts: NetCDF4OpenOptions = None,
         store_opts: NetCDF4StoreOptions = None,
@@ -717,15 +689,10 @@ def open_dataset(
         open_opts = open_opts if open_opts is not None else self.open_opts
         store_opts = store_opts if store_opts is not None else self.store_opts

-        # open_kwargs = asdict(open_opts)
-
         filename_or_obj = _normalize_path(filename_or_obj)
         store = NetCDF4DataStore.open(
             filename_or_obj,
             mode=mode,
-            # group=group,
-            # lock=lock,
-            # autoclose=autoclose,
             open_opts=open_opts,
             store_opts=store_opts,
             **kwargs,
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 497e98c6551..0fda4ecce86 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -368,8 +368,6 @@ def roundtrip(
         if open_kwargs is None:
             open_kwargs = {}
         with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path:
-            print("ZZ0:", save_kwargs)
-            print("ZZ1:", open_kwargs)
             self.save(data, path, **save_kwargs)
             with self.open(path, **open_kwargs) as ds:
                 yield ds
@@ -1556,14 +1554,10 @@ def test_write_groups(self) -> None:
         data1 = create_test_data()
         data2 = data1 * 2
         with create_tmp_file() as tmp_file:
-            print("----------------------------------------")
             self.save(data1, tmp_file, group="data/1")
-            print("----------------------------------------")
             self.save(data2, tmp_file, group="data/2", mode="a")
-            print("----------------------------------------")
             with self.open(tmp_file, group="data/1") as actual1:
                 assert_identical(data1, actual1)
-            print("----------------------------------------")
             with self.open(tmp_file, group="data/2") as actual2:
                 assert_identical(data2, actual2)
@@ -5423,7 +5417,6 @@ def convert_to_pydap_dataset(self, original):
     @contextlib.contextmanager
     def create_datasets(self, **kwargs):
         with open_example_dataset("bears.nc") as expected:
-            # print("QQ0:", expected["bears"].load())
            pydap_ds = self.convert_to_pydap_dataset(expected)
            actual = open_dataset(PydapDataStore(pydap_ds))
            if Version(np.__version__) < Version("2.3.0"):
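
The fields()-based routing that patches 3 and 4 apply in open_dataset and
to_netcdf reduces to a small, self-contained pattern. The sketch below
restates it with illustrative (non-xarray) names to document the design
choice of popping recognized keywords, so that unrecognized leftovers remain
visible for error handling:

    from dataclasses import dataclass, fields
    from typing import Optional

    @dataclass(frozen=True)
    class OpenOptions:
        format: Optional[str] = "NETCDF4"
        driver: Optional[str] = None

    @dataclass(frozen=True)
    class StoreOptions:
        group: Optional[str] = None
        autoclose: bool = False

    def partition_kwargs(kwargs, *option_classes):
        """Route flat kwargs into one instance per options dataclass,
        popping each recognized name so leftovers stay behind."""
        instances = []
        for cls in option_classes:
            names = {f.name for f in fields(cls)}
            picked = {k: kwargs.pop(k) for k in list(kwargs) if k in names}
            instances.append(cls(**picked))
        return instances

    kwargs = {"driver": "ros3", "group": "data/1", "unknown": 1}
    open_opts, store_opts = partition_kwargs(kwargs, OpenOptions, StoreOptions)
    assert kwargs == {"unknown": 1}  # unrecognized keywords remain
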