@@ -112,7 +112,9 @@ def __getitem__(self, key):
112
112
# could possibly have a work-around for 0d data here
113
113
114
114
115
- def _determine_zarr_chunks (enc_chunks , var_chunks , ndim , name , safe_chunks ):
115
+ def _determine_zarr_chunks (
116
+ enc_chunks , var_chunks , ndim , name , safe_chunks , region , mode , shape
117
+ ):
116
118
"""
117
119
Given encoding chunks (possibly None or []) and variable chunks
118
120
(possibly None or []).
@@ -163,7 +165,9 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
163
165
164
166
if len (enc_chunks_tuple ) != ndim :
165
167
# throw away encoding chunks, start over
166
- return _determine_zarr_chunks (None , var_chunks , ndim , name , safe_chunks )
168
+ return _determine_zarr_chunks (
169
+ None , var_chunks , ndim , name , safe_chunks , region , mode , shape
170
+ )
167
171
168
172
for x in enc_chunks_tuple :
169
173
if not isinstance (x , int ):
@@ -189,20 +193,58 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
189
193
# TODO: incorporate synchronizer to allow writes from multiple dask
190
194
# threads
191
195
if var_chunks and enc_chunks_tuple :
192
- for zchunk , dchunks in zip (enc_chunks_tuple , var_chunks , strict = True ):
193
- for dchunk in dchunks [:- 1 ]:
196
+ # If it is possible to write on partial chunks then it is not necessary to check
197
+ # the last one contained on the region
198
+ allow_partial_chunks = mode != "r+"
199
+
200
+ base_error = (
201
+ f"Specified zarr chunks encoding['chunks']={ enc_chunks_tuple !r} for "
202
+ f"variable named { name !r} would overlap multiple dask chunks { var_chunks !r} "
203
+ f"on the region { region } . "
204
+ f"Writing this array in parallel with dask could lead to corrupted data."
205
+ f"Consider either rechunking using `chunk()`, deleting "
206
+ f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
207
+ )
208
+
209
+ for zchunk , dchunks , interval , size in zip (
210
+ enc_chunks_tuple , var_chunks , region , shape , strict = True
211
+ ):
212
+ if not safe_chunks :
213
+ continue
214
+
215
+ for dchunk in dchunks [1 :- 1 ]:
194
216
if dchunk % zchunk :
195
- base_error = (
196
- f"Specified zarr chunks encoding['chunks']={ enc_chunks_tuple !r} for "
197
- f"variable named { name !r} would overlap multiple dask chunks { var_chunks !r} . "
198
- f"Writing this array in parallel with dask could lead to corrupted data."
199
- )
200
- if safe_chunks :
201
- raise ValueError (
202
- base_error
203
- + " Consider either rechunking using `chunk()`, deleting "
204
- "or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
205
- )
217
+ raise ValueError (base_error )
218
+
219
+ region_start = interval .start if interval .start else 0
220
+
221
+ if len (dchunks ) > 1 :
222
+ # The first border size is the amount of data that needs to be updated on the
223
+ # first chunk taking into account the region slice.
224
+ first_border_size = zchunk
225
+ if allow_partial_chunks :
226
+ first_border_size = zchunk - region_start % zchunk
227
+
228
+ if (dchunks [0 ] - first_border_size ) % zchunk :
229
+ raise ValueError (base_error )
230
+
231
+ if not allow_partial_chunks :
232
+ region_stop = interval .stop if interval .stop else size
233
+
234
+ if region_start % zchunk :
235
+ # The last chunk which can also be the only one is a partial chunk
236
+ # if it is not aligned at the beginning
237
+ raise ValueError (base_error )
238
+
239
+ if np .ceil (region_stop / zchunk ) == np .ceil (size / zchunk ):
240
+ # If the region is covering the last chunk then check
241
+ # if the reminder with the default chunk size
242
+ # is equal to the size of the last chunk
243
+ if dchunks [- 1 ] % zchunk != size % zchunk :
244
+ raise ValueError (base_error )
245
+ elif dchunks [- 1 ] % zchunk :
246
+ raise ValueError (base_error )
247
+
206
248
return enc_chunks_tuple
207
249
208
250
raise AssertionError ("We should never get here. Function logic must be wrong." )
@@ -243,7 +285,14 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):
243
285
244
286
245
287
def extract_zarr_variable_encoding (
246
- variable , raise_on_invalid = False , name = None , safe_chunks = True
288
+ variable ,
289
+ raise_on_invalid = False ,
290
+ name = None ,
291
+ * ,
292
+ safe_chunks = True ,
293
+ region = None ,
294
+ mode = None ,
295
+ shape = None ,
247
296
):
248
297
"""
249
298
Extract zarr encoding dictionary from xarray Variable
@@ -252,12 +301,18 @@ def extract_zarr_variable_encoding(
252
301
----------
253
302
variable : Variable
254
303
raise_on_invalid : bool, optional
255
-
304
+ name: str | Hashable, optional
305
+ safe_chunks: bool, optional
306
+ region: tuple[slice, ...], optional
307
+ mode: str, optional
308
+ shape: tuple[int, ...], optional
256
309
Returns
257
310
-------
258
311
encoding : dict
259
312
Zarr encoding for `variable`
260
313
"""
314
+
315
+ shape = shape if shape else variable .shape
261
316
encoding = variable .encoding .copy ()
262
317
263
318
safe_to_drop = {"source" , "original_shape" }
@@ -285,7 +340,14 @@ def extract_zarr_variable_encoding(
285
340
del encoding [k ]
286
341
287
342
chunks = _determine_zarr_chunks (
288
- encoding .get ("chunks" ), variable .chunks , variable .ndim , name , safe_chunks
343
+ enc_chunks = encoding .get ("chunks" ),
344
+ var_chunks = variable .chunks ,
345
+ ndim = variable .ndim ,
346
+ name = name ,
347
+ safe_chunks = safe_chunks ,
348
+ region = region ,
349
+ mode = mode ,
350
+ shape = shape ,
289
351
)
290
352
encoding ["chunks" ] = chunks
291
353
return encoding
@@ -762,16 +824,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
762
824
if v .encoding == {"_FillValue" : None } and fill_value is None :
763
825
v .encoding = {}
764
826
765
- # We need to do this for both new and existing variables to ensure we're not
766
- # writing to a partial chunk, even though we don't use the `encoding` value
767
- # when writing to an existing variable. See
768
- # https://github.com/pydata/xarray/issues/8371 for details.
769
- encoding = extract_zarr_variable_encoding (
770
- v ,
771
- raise_on_invalid = vn in check_encoding_set ,
772
- name = vn ,
773
- safe_chunks = self ._safe_chunks ,
774
- )
827
+ zarr_array = None
828
+ zarr_shape = None
829
+ write_region = self ._write_region if self ._write_region is not None else {}
830
+ write_region = {dim : write_region .get (dim , slice (None )) for dim in dims }
775
831
776
832
if name in existing_keys :
777
833
# existing variable
@@ -801,7 +857,40 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
801
857
)
802
858
else :
803
859
zarr_array = self .zarr_group [name ]
804
- else :
860
+
861
+ if self ._append_dim is not None and self ._append_dim in dims :
862
+ # resize existing variable
863
+ append_axis = dims .index (self ._append_dim )
864
+ assert write_region [self ._append_dim ] == slice (None )
865
+ write_region [self ._append_dim ] = slice (
866
+ zarr_array .shape [append_axis ], None
867
+ )
868
+
869
+ new_shape = list (zarr_array .shape )
870
+ new_shape [append_axis ] += v .shape [append_axis ]
871
+ zarr_array .resize (new_shape )
872
+
873
+ zarr_shape = zarr_array .shape
874
+
875
+ region = tuple (write_region [dim ] for dim in dims )
876
+
877
+ # We need to do this for both new and existing variables to ensure we're not
878
+ # writing to a partial chunk, even though we don't use the `encoding` value
879
+ # when writing to an existing variable. See
880
+ # https://github.com/pydata/xarray/issues/8371 for details.
881
+ # Note: Ideally there should be two functions, one for validating the chunks and
882
+ # another one for extracting the encoding.
883
+ encoding = extract_zarr_variable_encoding (
884
+ v ,
885
+ raise_on_invalid = vn in check_encoding_set ,
886
+ name = vn ,
887
+ safe_chunks = self ._safe_chunks ,
888
+ region = region ,
889
+ mode = self ._mode ,
890
+ shape = zarr_shape ,
891
+ )
892
+
893
+ if name not in existing_keys :
805
894
# new variable
806
895
encoded_attrs = {}
807
896
# the magic for storing the hidden dimension data
@@ -833,22 +922,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
833
922
)
834
923
zarr_array = _put_attrs (zarr_array , encoded_attrs )
835
924
836
- write_region = self ._write_region if self ._write_region is not None else {}
837
- write_region = {dim : write_region .get (dim , slice (None )) for dim in dims }
838
-
839
- if self ._append_dim is not None and self ._append_dim in dims :
840
- # resize existing variable
841
- append_axis = dims .index (self ._append_dim )
842
- assert write_region [self ._append_dim ] == slice (None )
843
- write_region [self ._append_dim ] = slice (
844
- zarr_array .shape [append_axis ], None
845
- )
846
-
847
- new_shape = list (zarr_array .shape )
848
- new_shape [append_axis ] += v .shape [append_axis ]
849
- zarr_array .resize (new_shape )
850
-
851
- region = tuple (write_region [dim ] for dim in dims )
852
925
writer .add (v .data , zarr_array , region )
853
926
854
927
def close (self ) -> None :
@@ -897,9 +970,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
897
970
if not isinstance (region , dict ):
898
971
raise TypeError (f"``region`` must be a dict, got { type (region )} " )
899
972
if any (v == "auto" for v in region .values ()):
900
- if self ._mode != "r+" :
973
+ if self ._mode not in [ "r+" , "a" ] :
901
974
raise ValueError (
902
- f"``mode`` must be 'r+' when using ``region='auto'``, got { self ._mode !r} "
975
+ f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got { self ._mode !r} "
903
976
)
904
977
region = self ._auto_detect_regions (ds , region )
905
978
0 commit comments