Skip to content

Commit 5fcf82c

Browse files
authored
bytes supported as attributes, with additional formatting check for h5netcdf engine (#9407)
1 parent 12d8cfa commit 5fcf82c

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

xarray/backends/api.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def check_name(name: Hashable):
167167
check_name(k)
168168

169169

170-
def _validate_attrs(dataset, invalid_netcdf=False):
170+
def _validate_attrs(dataset, engine, invalid_netcdf=False):
171171
"""`attrs` must have a string key and a value which is either: a number,
172172
a string, bytes, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_.
173173
@@ -177,8 +177,8 @@ def _validate_attrs(dataset, invalid_netcdf=False):
177177
`invalid_netcdf=True`.
178178
"""
179179

180-
valid_types = (str, Number, np.ndarray, np.number, list, tuple)
181-
if invalid_netcdf:
180+
valid_types = (str, Number, np.ndarray, np.number, list, tuple, bytes)
181+
if invalid_netcdf and engine == "h5netcdf":
182182
valid_types += (np.bool_,)
183183

184184
def check_attr(name, value, valid_types):
@@ -202,6 +202,23 @@ def check_attr(name, value, valid_types):
202202
f"{', '.join([vtype.__name__ for vtype in valid_types])}"
203203
)
204204

205+
if isinstance(value, bytes) and engine == "h5netcdf":
206+
try:
207+
value.decode("utf-8")
208+
except UnicodeDecodeError as e:
209+
raise ValueError(
210+
f"Invalid value provided for attribute '{name!r}': {value!r}. "
211+
"Only binary data derived from UTF-8 encoded strings is allowed "
212+
f"for the '{engine}' engine. Consider using the 'netcdf4' engine."
213+
) from e
214+
215+
if b"\x00" in value:
216+
raise ValueError(
217+
f"Invalid value provided for attribute '{name!r}': {value!r}. "
218+
f"Null characters are not permitted for the '{engine}' engine. "
219+
"Consider using the 'netcdf4' engine."
220+
)
221+
205222
# Check attrs on the dataset itself
206223
for k, v in dataset.attrs.items():
207224
check_attr(k, v, valid_types)
@@ -1353,7 +1370,7 @@ def to_netcdf(
13531370

13541371
# validate Dataset keys, DataArray names, and attr keys/values
13551372
_validate_dataset_names(dataset)
1356-
_validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf")
1373+
_validate_attrs(dataset, engine, invalid_netcdf)
13571374

13581375
try:
13591376
store_open = WRITEABLE_STORES[engine]

xarray/tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,26 @@ def d(request, backend, type) -> DataArray | Dataset:
139139
raise ValueError
140140

141141

142+
@pytest.fixture
143+
def byte_attrs_dataset():
144+
"""For testing issue #9407"""
145+
null_byte = b"\x00"
146+
other_bytes = bytes(range(1, 256))
147+
ds = Dataset({"x": 1}, coords={"x_coord": [1]})
148+
ds["x"].attrs["null_byte"] = null_byte
149+
ds["x"].attrs["other_bytes"] = other_bytes
150+
151+
expected = ds.copy()
152+
expected["x"].attrs["null_byte"] = ""
153+
expected["x"].attrs["other_bytes"] = other_bytes.decode(errors="replace")
154+
155+
return {
156+
"input": ds,
157+
"expected": expected,
158+
"h5netcdf_error": r"Invalid value provided for attribute .*: .*\. Null characters .*",
159+
}
160+
161+
142162
@pytest.fixture(scope="module")
143163
def create_test_datatree():
144164
"""

xarray/tests/test_backends.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,6 +1404,13 @@ def test_refresh_from_disk(self) -> None:
14041404
a.close()
14051405
b.close()
14061406

1407+
def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
1408+
# test for issue #9407
1409+
input = byte_attrs_dataset["input"]
1410+
expected = byte_attrs_dataset["expected"]
1411+
with self.roundtrip(input) as actual:
1412+
assert_identical(actual, expected)
1413+
14071414

14081415
_counter = itertools.count()
14091416

@@ -3861,6 +3868,10 @@ def test_decode_utf8_warning(self) -> None:
38613868
assert ds.title == title
38623869
assert "attribute 'title' of h5netcdf object '/'" in str(w[0].message)
38633870

3871+
def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
3872+
with pytest.raises(ValueError, match=byte_attrs_dataset["h5netcdf_error"]):
3873+
super().test_byte_attrs(byte_attrs_dataset)
3874+
38643875

38653876
@requires_h5netcdf
38663877
@requires_netCDF4

0 commit comments

Comments
 (0)