Skip to content

Commit 095d47f

Browse files
josephnowakpre-commit-ci[bot]max-sixty
authored
Improve safe chunk validation (#9559)
* fix safe chunks validation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix safe chunks validation * Update xarray/tests/test_backends.py Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * The validation of the chunks now is able to detect full or partial chunk and raise a proper error based on the mode selected, it is also possible to use the auto region detection with the mode "a" * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * The test_extract_zarr_variable_encoding does not need to use the region parameter * Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+ * Inline the code of the allow_partial_chunks and end, document the parameter in order on the extract_zarr_variable_encoding method, raise the correct error if the border size is smaller than the zchunk on mode equal to r+ * Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete" * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Now the mode r+ is able to update the last chunk of Zarr even if it is not "complete" * Add a typehint to the modes to avoid issues with mypy * Fix the detection of the last chunk * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the whats-new and add mode="w" to the new test case * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Co-authored-by: Maximilian Roos <m@maxroos.com>
1 parent 7bdc6d4 commit 095d47f

File tree

5 files changed

+312
-54
lines changed

5 files changed

+312
-54
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ Bug fixes
6161
<https://github.com/spencerkclark>`_.
6262
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
6363
By `Deepak Cherian <https://github.com/dcherian>`_.
64-
64+
- Fix the safe_chunks validation option on the to_zarr method
65+
(:issue:`5511`, :pull:`9559`). By `Joseph Nowak
66+
<https://github.com/josephnowak>`_.
6567

6668
Documentation
6769
~~~~~~~~~~~~~

xarray/backends/zarr.py

Lines changed: 120 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ def __getitem__(self, key):
112112
# could possibly have a work-around for 0d data here
113113

114114

115-
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
115+
def _determine_zarr_chunks(
116+
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
117+
):
116118
"""
117119
Given encoding chunks (possibly None or []) and variable chunks
118120
(possibly None or []).
@@ -163,7 +165,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
163165

164166
if len(enc_chunks_tuple) != ndim:
165167
# throw away encoding chunks, start over
166-
return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)
168+
return _determine_zarr_chunks(
169+
None, var_chunks, ndim, name, safe_chunks, region, mode, shape
170+
)
167171

168172
for x in enc_chunks_tuple:
169173
if not isinstance(x, int):
@@ -189,20 +193,58 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
189193
# TODO: incorporate synchronizer to allow writes from multiple dask
190194
# threads
191195
if var_chunks and enc_chunks_tuple:
192-
for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True):
193-
for dchunk in dchunks[:-1]:
196+
# If it is possible to write on partial chunks then it is not necessary to check
197+
# the last one contained on the region
198+
allow_partial_chunks = mode != "r+"
199+
200+
base_error = (
201+
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
202+
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r} "
203+
f"on the region {region}. "
204+
f"Writing this array in parallel with dask could lead to corrupted data."
205+
f"Consider either rechunking using `chunk()`, deleting "
206+
f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
207+
)
208+
209+
for zchunk, dchunks, interval, size in zip(
210+
enc_chunks_tuple, var_chunks, region, shape, strict=True
211+
):
212+
if not safe_chunks:
213+
continue
214+
215+
for dchunk in dchunks[1:-1]:
194216
if dchunk % zchunk:
195-
base_error = (
196-
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
197-
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
198-
f"Writing this array in parallel with dask could lead to corrupted data."
199-
)
200-
if safe_chunks:
201-
raise ValueError(
202-
base_error
203-
+ " Consider either rechunking using `chunk()`, deleting "
204-
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
205-
)
217+
raise ValueError(base_error)
218+
219+
region_start = interval.start if interval.start else 0
220+
221+
if len(dchunks) > 1:
222+
# The first border size is the amount of data that needs to be updated on the
223+
# first chunk taking into account the region slice.
224+
first_border_size = zchunk
225+
if allow_partial_chunks:
226+
first_border_size = zchunk - region_start % zchunk
227+
228+
if (dchunks[0] - first_border_size) % zchunk:
229+
raise ValueError(base_error)
230+
231+
if not allow_partial_chunks:
232+
region_stop = interval.stop if interval.stop else size
233+
234+
if region_start % zchunk:
235+
# The last chunk which can also be the only one is a partial chunk
236+
# if it is not aligned at the beginning
237+
raise ValueError(base_error)
238+
239+
if np.ceil(region_stop / zchunk) == np.ceil(size / zchunk):
240+
# If the region is covering the last chunk then check
241+
# if the reminder with the default chunk size
242+
# is equal to the size of the last chunk
243+
if dchunks[-1] % zchunk != size % zchunk:
244+
raise ValueError(base_error)
245+
elif dchunks[-1] % zchunk:
246+
raise ValueError(base_error)
247+
206248
return enc_chunks_tuple
207249

208250
raise AssertionError("We should never get here. Function logic must be wrong.")
@@ -243,7 +285,14 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):
243285

244286

245287
def extract_zarr_variable_encoding(
246-
variable, raise_on_invalid=False, name=None, safe_chunks=True
288+
variable,
289+
raise_on_invalid=False,
290+
name=None,
291+
*,
292+
safe_chunks=True,
293+
region=None,
294+
mode=None,
295+
shape=None,
247296
):
248297
"""
249298
Extract zarr encoding dictionary from xarray Variable
@@ -252,12 +301,18 @@ def extract_zarr_variable_encoding(
252301
----------
253302
variable : Variable
254303
raise_on_invalid : bool, optional
255-
304+
name: str | Hashable, optional
305+
safe_chunks: bool, optional
306+
region: tuple[slice, ...], optional
307+
mode: str, optional
308+
shape: tuple[int, ...], optional
256309
Returns
257310
-------
258311
encoding : dict
259312
Zarr encoding for `variable`
260313
"""
314+
315+
shape = shape if shape else variable.shape
261316
encoding = variable.encoding.copy()
262317

263318
safe_to_drop = {"source", "original_shape"}
@@ -285,7 +340,14 @@ def extract_zarr_variable_encoding(
285340
del encoding[k]
286341

287342
chunks = _determine_zarr_chunks(
288-
encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
343+
enc_chunks=encoding.get("chunks"),
344+
var_chunks=variable.chunks,
345+
ndim=variable.ndim,
346+
name=name,
347+
safe_chunks=safe_chunks,
348+
region=region,
349+
mode=mode,
350+
shape=shape,
289351
)
290352
encoding["chunks"] = chunks
291353
return encoding
@@ -762,16 +824,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
762824
if v.encoding == {"_FillValue": None} and fill_value is None:
763825
v.encoding = {}
764826

765-
# We need to do this for both new and existing variables to ensure we're not
766-
# writing to a partial chunk, even though we don't use the `encoding` value
767-
# when writing to an existing variable. See
768-
# https://github.com/pydata/xarray/issues/8371 for details.
769-
encoding = extract_zarr_variable_encoding(
770-
v,
771-
raise_on_invalid=vn in check_encoding_set,
772-
name=vn,
773-
safe_chunks=self._safe_chunks,
774-
)
827+
zarr_array = None
828+
zarr_shape = None
829+
write_region = self._write_region if self._write_region is not None else {}
830+
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
775831

776832
if name in existing_keys:
777833
# existing variable
@@ -801,7 +857,40 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
801857
)
802858
else:
803859
zarr_array = self.zarr_group[name]
804-
else:
860+
861+
if self._append_dim is not None and self._append_dim in dims:
862+
# resize existing variable
863+
append_axis = dims.index(self._append_dim)
864+
assert write_region[self._append_dim] == slice(None)
865+
write_region[self._append_dim] = slice(
866+
zarr_array.shape[append_axis], None
867+
)
868+
869+
new_shape = list(zarr_array.shape)
870+
new_shape[append_axis] += v.shape[append_axis]
871+
zarr_array.resize(new_shape)
872+
873+
zarr_shape = zarr_array.shape
874+
875+
region = tuple(write_region[dim] for dim in dims)
876+
877+
# We need to do this for both new and existing variables to ensure we're not
878+
# writing to a partial chunk, even though we don't use the `encoding` value
879+
# when writing to an existing variable. See
880+
# https://github.com/pydata/xarray/issues/8371 for details.
881+
# Note: Ideally there should be two functions, one for validating the chunks and
882+
# another one for extracting the encoding.
883+
encoding = extract_zarr_variable_encoding(
884+
v,
885+
raise_on_invalid=vn in check_encoding_set,
886+
name=vn,
887+
safe_chunks=self._safe_chunks,
888+
region=region,
889+
mode=self._mode,
890+
shape=zarr_shape,
891+
)
892+
893+
if name not in existing_keys:
805894
# new variable
806895
encoded_attrs = {}
807896
# the magic for storing the hidden dimension data
@@ -833,22 +922,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
833922
)
834923
zarr_array = _put_attrs(zarr_array, encoded_attrs)
835924

836-
write_region = self._write_region if self._write_region is not None else {}
837-
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
838-
839-
if self._append_dim is not None and self._append_dim in dims:
840-
# resize existing variable
841-
append_axis = dims.index(self._append_dim)
842-
assert write_region[self._append_dim] == slice(None)
843-
write_region[self._append_dim] = slice(
844-
zarr_array.shape[append_axis], None
845-
)
846-
847-
new_shape = list(zarr_array.shape)
848-
new_shape[append_axis] += v.shape[append_axis]
849-
zarr_array.resize(new_shape)
850-
851-
region = tuple(write_region[dim] for dim in dims)
852925
writer.add(v.data, zarr_array, region)
853926

854927
def close(self) -> None:
@@ -897,9 +970,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
897970
if not isinstance(region, dict):
898971
raise TypeError(f"``region`` must be a dict, got {type(region)}")
899972
if any(v == "auto" for v in region.values()):
900-
if self._mode != "r+":
973+
if self._mode not in ["r+", "a"]:
901974
raise ValueError(
902-
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
975+
f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got {self._mode!r}"
903976
)
904977
region = self._auto_detect_regions(ds, region)
905978

xarray/core/dataarray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4316,6 +4316,14 @@ def to_zarr(
43164316
if Zarr arrays are written in parallel. This option may be useful in combination
43174317
with ``compute=False`` to initialize a Zarr store from an existing
43184318
DataArray with arbitrary chunk structure.
4319+
In addition to the many-to-one relationship validation, it also detects partial
4320+
chunks writes when using the region parameter,
4321+
these partial chunks are considered unsafe in the mode "r+" but safe in
4322+
the mode "a".
4323+
Note: Even with these validations it can still be unsafe to write
4324+
two or more chunked arrays in the same location in parallel if they are
4325+
not writing in independent regions, for those cases it is better to use
4326+
a synchronizer.
43194327
storage_options : dict, optional
43204328
Any additional parameters for the storage backend (ignored for local
43214329
paths).

xarray/core/dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2509,6 +2509,14 @@ def to_zarr(
25092509
if Zarr arrays are written in parallel. This option may be useful in combination
25102510
with ``compute=False`` to initialize a Zarr from an existing
25112511
Dataset with arbitrary chunk structure.
2512+
In addition to the many-to-one relationship validation, it also detects partial
2513+
chunks writes when using the region parameter,
2514+
these partial chunks are considered unsafe in the mode "r+" but safe in
2515+
the mode "a".
2516+
Note: Even with these validations it can still be unsafe to write
2517+
two or more chunked arrays in the same location in parallel if they are
2518+
not writing in independent regions, for those cases it is better to use
2519+
a synchronizer.
25122520
storage_options : dict, optional
25132521
Any additional parameters for the storage backend (ignored for local
25142522
paths).

0 commit comments

Comments
 (0)