Skip to content

Commit c16fa1e

Browse files
kmuehlbauerZedThreepre-commit-ci[bot]
authored
towards new h5netcdf/netcdf4 features (#9509)
* MNT: towards new h5netcdf/netcdf4 features * Update xarray/backends/h5netcdf_.py * Update xarray/backends/netCDF4_.py * Update xarray/tests/test_backends.py * [pre-commit.ci] auto fixes from pre-commit.com hooks * FIX: correct handling of EnumType on dtype creation * FIX: only handle enumtypes if they are available from h5netcdf * whats-new.rst entry, minor fix * Update xarray/backends/netCDF4_.py Co-authored-by: Peter Hill <zed.three@gmail.com> * attempt to fix typing * use pytest recwarn instead of an empty context manager to make mypy happy * check for invalid_netcdf warning, too * fix howdoi.rst table entry --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Peter Hill <zed.three@gmail.com>
1 parent 0063a51 commit c16fa1e

File tree

11 files changed

+159
-56
lines changed

11 files changed

+159
-56
lines changed

doc/howdoi.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ How do I ...
5858
* - apply a function on all data variables in a Dataset
5959
- :py:meth:`Dataset.map`
6060
* - write xarray objects with complex values to a netCDF file
61-
- :py:func:`Dataset.to_netcdf`, :py:func:`DataArray.to_netcdf` specifying ``engine="h5netcdf", invalid_netcdf=True``
61+
- :py:func:`Dataset.to_netcdf`, :py:func:`DataArray.to_netcdf` specifying ``engine="h5netcdf"`` or :py:func:`Dataset.to_netcdf`, :py:func:`DataArray.to_netcdf` specifying ``engine="netCDF4", auto_complex=True``
6262
* - make xarray objects look like other xarray objects
6363
- :py:func:`~xarray.ones_like`, :py:func:`~xarray.zeros_like`, :py:func:`~xarray.full_like`, :py:meth:`Dataset.reindex_like`, :py:meth:`Dataset.interp_like`, :py:meth:`Dataset.broadcast_like`, :py:meth:`DataArray.reindex_like`, :py:meth:`DataArray.interp_like`, :py:meth:`DataArray.broadcast_like`
6464
* - Make sure my datasets have values at the same coordinate locations

doc/user-guide/io.rst

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -566,29 +566,12 @@ This is not CF-compliant but again facilitates roundtripping of xarray datasets.
566566
Invalid netCDF files
567567
~~~~~~~~~~~~~~~~~~~~
568568

569-
The library ``h5netcdf`` allows writing some dtypes (booleans, complex, ...) that aren't
569+
The library ``h5netcdf`` allows writing some dtypes that aren't
570570
allowed in netCDF4 (see
571-
`h5netcdf documentation <https://github.com/shoyer/h5netcdf#invalid-netcdf-files>`_).
571+
`h5netcdf documentation <https://github.com/h5netcdf/h5netcdf#invalid-netcdf-files>`_).
572572
This feature is available through :py:meth:`DataArray.to_netcdf` and
573573
:py:meth:`Dataset.to_netcdf` when used with ``engine="h5netcdf"``
574-
and currently raises a warning unless ``invalid_netcdf=True`` is set:
575-
576-
.. ipython:: python
577-
:okwarning:
578-
579-
# Writing complex valued data
580-
da = xr.DataArray([1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j])
581-
da.to_netcdf("complex.nc", engine="h5netcdf", invalid_netcdf=True)
582-
583-
# Reading it back
584-
reopened = xr.open_dataarray("complex.nc", engine="h5netcdf")
585-
reopened
586-
587-
.. ipython:: python
588-
:suppress:
589-
590-
reopened.close()
591-
os.remove("complex.nc")
574+
and currently raises a warning unless ``invalid_netcdf=True`` is set.
592575

593576
.. warning::
594577

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ New Features
3535
- Added support for vectorized interpolation using additional interpolators
3636
from the ``scipy.interpolate`` module (:issue:`9049`, :pull:`9526`).
3737
By `Holly Mandel <https://github.com/hollymandel>`_.
38+
- Implement handling of complex numbers (netcdf4/h5netcdf) and enums (h5netcdf) (:issue:`9246`, :issue:`3297`, :pull:`9509`).
39+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
3840

3941
Breaking changes
4042
~~~~~~~~~~~~~~~~

xarray/backends/api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,7 @@ def to_netcdf(
12131213
*,
12141214
multifile: Literal[True],
12151215
invalid_netcdf: bool = False,
1216+
auto_complex: bool | None = None,
12161217
) -> tuple[ArrayWriter, AbstractDataStore]: ...
12171218

12181219

@@ -1230,6 +1231,7 @@ def to_netcdf(
12301231
compute: bool = True,
12311232
multifile: Literal[False] = False,
12321233
invalid_netcdf: bool = False,
1234+
auto_complex: bool | None = None,
12331235
) -> bytes: ...
12341236

12351237

@@ -1248,6 +1250,7 @@ def to_netcdf(
12481250
compute: Literal[False],
12491251
multifile: Literal[False] = False,
12501252
invalid_netcdf: bool = False,
1253+
auto_complex: bool | None = None,
12511254
) -> Delayed: ...
12521255

12531256

@@ -1265,6 +1268,7 @@ def to_netcdf(
12651268
compute: Literal[True] = True,
12661269
multifile: Literal[False] = False,
12671270
invalid_netcdf: bool = False,
1271+
auto_complex: bool | None = None,
12681272
) -> None: ...
12691273

12701274

@@ -1283,6 +1287,7 @@ def to_netcdf(
12831287
compute: bool = False,
12841288
multifile: Literal[False] = False,
12851289
invalid_netcdf: bool = False,
1290+
auto_complex: bool | None = None,
12861291
) -> Delayed | None: ...
12871292

12881293

@@ -1301,6 +1306,7 @@ def to_netcdf(
13011306
compute: bool = False,
13021307
multifile: bool = False,
13031308
invalid_netcdf: bool = False,
1309+
auto_complex: bool | None = None,
13041310
) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None: ...
13051311

13061312

@@ -1318,6 +1324,7 @@ def to_netcdf(
13181324
compute: bool = False,
13191325
multifile: bool = False,
13201326
invalid_netcdf: bool = False,
1327+
auto_complex: bool | None = None,
13211328
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: ...
13221329

13231330

@@ -1333,6 +1340,7 @@ def to_netcdf(
13331340
compute: bool = True,
13341341
multifile: bool = False,
13351342
invalid_netcdf: bool = False,
1343+
auto_complex: bool | None = None,
13361344
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
13371345
"""This function creates an appropriate datastore for writing a dataset to
13381346
disk as a netCDF file
@@ -1400,6 +1408,9 @@ def to_netcdf(
14001408
raise ValueError(
14011409
f"unrecognized option 'invalid_netcdf' for engine {engine}"
14021410
)
1411+
if auto_complex is not None:
1412+
kwargs["auto_complex"] = auto_complex
1413+
14031414
store = store_open(target, mode, format, group, **kwargs)
14041415

14051416
if unlimited_dims is None:

xarray/backends/h5netcdf_.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from collections.abc import Callable, Iterable
77
from typing import TYPE_CHECKING, Any
88

9+
import numpy as np
10+
911
from xarray.backends.common import (
1012
BACKEND_ENTRYPOINTS,
1113
BackendEntrypoint,
@@ -17,6 +19,7 @@
1719
from xarray.backends.locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock
1820
from xarray.backends.netCDF4_ import (
1921
BaseNetCDF4Array,
22+
_build_and_get_enum,
2023
_encode_nc4_variable,
2124
_ensure_no_forward_slash_in_name,
2225
_extract_nc4_variable_encoding,
@@ -195,6 +198,7 @@ def ds(self):
195198
return self._acquire()
196199

197200
def open_store_variable(self, name, var):
201+
import h5netcdf
198202
import h5py
199203

200204
dimensions = var.dimensions
@@ -230,6 +234,18 @@ def open_store_variable(self, name, var):
230234
elif vlen_dtype is not None: # pragma: no cover
231235
# xarray doesn't support writing arbitrary vlen dtypes yet.
232236
pass
237+
# just check if datatype is available and create dtype
238+
# this check can be removed if h5netcdf >= 1.4.0 for any environment
239+
elif (datatype := getattr(var, "datatype", None)) and isinstance(
240+
datatype, h5netcdf.core.EnumType
241+
):
242+
encoding["dtype"] = np.dtype(
243+
data.dtype,
244+
metadata={
245+
"enum": datatype.enum_dict,
246+
"enum_name": datatype.name,
247+
},
248+
)
233249
else:
234250
encoding["dtype"] = var.dtype
235251

@@ -281,6 +297,14 @@ def prepare_variable(
281297
if dtype is str:
282298
dtype = h5py.special_dtype(vlen=str)
283299

300+
# check enum metadata and use h5netcdf.core.EnumType
301+
if (
302+
hasattr(self.ds, "enumtypes")
303+
and (meta := np.dtype(dtype).metadata)
304+
and (e_name := meta.get("enum_name"))
305+
and (e_dict := meta.get("enum"))
306+
):
307+
dtype = _build_and_get_enum(self, name, dtype, e_name, e_dict)
284308
encoding = _extract_h5nc_encoding(variable, raise_on_invalid=check_encoding)
285309
kwargs = {}
286310

xarray/backends/netCDF4_.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
if TYPE_CHECKING:
4343
from io import BufferedIOBase
4444

45+
from h5netcdf.core import EnumType as h5EnumType
46+
from netCDF4 import EnumType as ncEnumType
47+
4548
from xarray.backends.common import AbstractDataStore
4649
from xarray.core.dataset import Dataset
4750
from xarray.core.datatree import DataTree
@@ -317,6 +320,39 @@ def _is_list_of_strings(value) -> bool:
317320
return arr.dtype.kind in ["U", "S"] and arr.size > 1
318321

319322

323+
def _build_and_get_enum(
324+
store, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
325+
) -> ncEnumType | h5EnumType:
326+
"""
327+
Add or get the netCDF4 Enum based on the dtype in encoding.
328+
The return type should be ``netCDF4.EnumType``,
329+
but we avoid importing netCDF4 globally for performances.
330+
"""
331+
if enum_name not in store.ds.enumtypes:
332+
create_func = (
333+
store.ds.createEnumType
334+
if isinstance(store, NetCDF4DataStore)
335+
else store.ds.create_enumtype
336+
)
337+
return create_func(
338+
dtype,
339+
enum_name,
340+
enum_dict,
341+
)
342+
datatype = store.ds.enumtypes[enum_name]
343+
if datatype.enum_dict != enum_dict:
344+
error_msg = (
345+
f"Cannot save variable `{var_name}` because an enum"
346+
f" `{enum_name}` already exists in the Dataset but has"
347+
" a different definition. To fix this error, make sure"
348+
" all variables have a uniquely named enum in their"
349+
" `encoding['dtype'].metadata` or, if they should share"
350+
" the same enum type, make sure the enums are identical."
351+
)
352+
raise ValueError(error_msg)
353+
return datatype
354+
355+
320356
class NetCDF4DataStore(WritableCFDataStore):
321357
"""Store for reading and writing data via the Python-NetCDF4 library.
322358
@@ -370,6 +406,7 @@ def open(
370406
clobber=True,
371407
diskless=False,
372408
persist=False,
409+
auto_complex=None,
373410
lock=None,
374411
lock_maker=None,
375412
autoclose=False,
@@ -402,8 +439,13 @@ def open(
402439
lock = combine_locks([base_lock, get_write_lock(filename)])
403440

404441
kwargs = dict(
405-
clobber=clobber, diskless=diskless, persist=persist, format=format
442+
clobber=clobber,
443+
diskless=diskless,
444+
persist=persist,
445+
format=format,
406446
)
447+
if auto_complex is not None:
448+
kwargs["auto_complex"] = auto_complex
407449
manager = CachingFileManager(
408450
netCDF4.Dataset, filename, mode=mode, kwargs=kwargs
409451
)
@@ -516,7 +558,7 @@ def prepare_variable(
516558
and (e_name := meta.get("enum_name"))
517559
and (e_dict := meta.get("enum"))
518560
):
519-
datatype = self._build_and_get_enum(name, datatype, e_name, e_dict)
561+
datatype = _build_and_get_enum(self, name, datatype, e_name, e_dict)
520562
encoding = _extract_nc4_variable_encoding(
521563
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
522564
)
@@ -547,33 +589,6 @@ def prepare_variable(
547589

548590
return target, variable.data
549591

550-
def _build_and_get_enum(
551-
self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
552-
) -> Any:
553-
"""
554-
Add or get the netCDF4 Enum based on the dtype in encoding.
555-
The return type should be ``netCDF4.EnumType``,
556-
but we avoid importing netCDF4 globally for performances.
557-
"""
558-
if enum_name not in self.ds.enumtypes:
559-
return self.ds.createEnumType(
560-
dtype,
561-
enum_name,
562-
enum_dict,
563-
)
564-
datatype = self.ds.enumtypes[enum_name]
565-
if datatype.enum_dict != enum_dict:
566-
error_msg = (
567-
f"Cannot save variable `{var_name}` because an enum"
568-
f" `{enum_name}` already exists in the Dataset but have"
569-
" a different definition. To fix this error, make sure"
570-
" each variable have a uniquely named enum in their"
571-
" `encoding['dtype'].metadata` or, if they should share"
572-
" the same enum type, make sure the enums are identical."
573-
)
574-
raise ValueError(error_msg)
575-
return datatype
576-
577592
def sync(self):
578593
self.ds.sync()
579594

@@ -642,6 +657,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
642657
clobber=True,
643658
diskless=False,
644659
persist=False,
660+
auto_complex=None,
645661
lock=None,
646662
autoclose=False,
647663
) -> Dataset:
@@ -654,6 +670,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti
654670
clobber=clobber,
655671
diskless=diskless,
656672
persist=persist,
673+
auto_complex=auto_complex,
657674
lock=lock,
658675
autoclose=autoclose,
659676
)
@@ -688,6 +705,7 @@ def open_datatree(
688705
clobber=True,
689706
diskless=False,
690707
persist=False,
708+
auto_complex=None,
691709
lock=None,
692710
autoclose=False,
693711
**kwargs,
@@ -715,6 +733,7 @@ def open_groups_as_dict(
715733
clobber=True,
716734
diskless=False,
717735
persist=False,
736+
auto_complex=None,
718737
lock=None,
719738
autoclose=False,
720739
**kwargs,

xarray/coding/variables.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def _choose_float_dtype(
537537
if dtype.itemsize <= 2 and np.issubdtype(dtype, np.integer):
538538
return np.float32
539539
# For all other types and circumstances, we just use float64.
540+
# Todo: with nc-complex from netcdf4-python >= 1.7.0 this is available
540541
# (safe because eg. complex numbers are not supported in NetCDF)
541542
return np.float64
542543

xarray/core/dataarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3994,6 +3994,7 @@ def to_netcdf(
39943994
unlimited_dims: Iterable[Hashable] | None = None,
39953995
compute: bool = True,
39963996
invalid_netcdf: bool = False,
3997+
auto_complex: bool | None = None,
39973998
) -> bytes: ...
39983999

39994000
# compute=False returns dask.Delayed
@@ -4010,6 +4011,7 @@ def to_netcdf(
40104011
*,
40114012
compute: Literal[False],
40124013
invalid_netcdf: bool = False,
4014+
auto_complex: bool | None = None,
40134015
) -> Delayed: ...
40144016

40154017
# default return None
@@ -4025,6 +4027,7 @@ def to_netcdf(
40254027
unlimited_dims: Iterable[Hashable] | None = None,
40264028
compute: Literal[True] = True,
40274029
invalid_netcdf: bool = False,
4030+
auto_complex: bool | None = None,
40284031
) -> None: ...
40294032

40304033
# if compute cannot be evaluated at type check time
@@ -4041,6 +4044,7 @@ def to_netcdf(
40414044
unlimited_dims: Iterable[Hashable] | None = None,
40424045
compute: bool = True,
40434046
invalid_netcdf: bool = False,
4047+
auto_complex: bool | None = None,
40444048
) -> Delayed | None: ...
40454049

40464050
def to_netcdf(
@@ -4054,6 +4058,7 @@ def to_netcdf(
40544058
unlimited_dims: Iterable[Hashable] | None = None,
40554059
compute: bool = True,
40564060
invalid_netcdf: bool = False,
4061+
auto_complex: bool | None = None,
40574062
) -> bytes | Delayed | None:
40584063
"""Write DataArray contents to a netCDF file.
40594064
@@ -4170,6 +4175,7 @@ def to_netcdf(
41704175
compute=compute,
41714176
multifile=False,
41724177
invalid_netcdf=invalid_netcdf,
4178+
auto_complex=auto_complex,
41734179
)
41744180

41754181
# compute=True (default) returns ZarrStore

0 commit comments

Comments
 (0)