@@ -125,6 +125,7 @@ def _write_with_tensorstore(
     # Use full array shape if provided, otherwise use the region array shape
     dataset_shape = full_array_shape if full_array_shape is not None else array.shape
 
+    # Build the base spec
     spec = {
         "kvstore": {
             "driver": "file",
@@ -134,33 +135,36 @@ def _write_with_tensorstore(
134135 "shape" : dataset_shape ,
135136 },
136137 }
138+
137139 if zarr_format == 2 :
138140 spec ["driver" ] = "zarr" if zarr_version_major < 3 else "zarr2"
139- spec ["metadata" ]["chunks" ] = chunks
140141 spec ["metadata" ]["dimension_separator" ] = "/"
141142 spec ["metadata" ]["dtype" ] = array .dtype .str
143+ # Only add chunk info when creating the dataset
144+ if create_dataset :
145+ spec ["metadata" ]["chunks" ] = chunks
142146 elif zarr_format == 3 :
143147 spec ["driver" ] = "zarr3"
144- spec ["metadata" ]["chunk_grid" ] = {
145- "name" : "regular" ,
146- "configuration" : {"chunk_shape" : chunks },
147- }
148148 spec ["metadata" ]["data_type" ] = _numpy_to_zarr_dtype (array .dtype )
149- spec [' metadata' ]["chunk_key_encoding" ] = {
149+ spec [" metadata" ]["chunk_key_encoding" ] = {
150150 "name" : "default" ,
151- "configuration" : {
152- "separator" : "/"
153- }
151+ "configuration" : {"separator" : "/" },
154152 }
155153 if dimension_names :
156154 spec ["metadata" ]["dimension_names" ] = dimension_names
157- if internal_chunk_shape :
158- spec ["metadata" ]["codecs" ] = [
159- {
160- "name" : "sharding_indexed" ,
161- "configuration" : {"chunk_shape" : internal_chunk_shape },
162- }
163- ]
155+ # Only add chunk info when creating the dataset
156+ if create_dataset :
157+ spec ["metadata" ]["chunk_grid" ] = {
158+ "name" : "regular" ,
159+ "configuration" : {"chunk_shape" : chunks },
160+ }
161+ if internal_chunk_shape :
162+ spec ["metadata" ]["codecs" ] = [
163+ {
164+ "name" : "sharding_indexed" ,
165+ "configuration" : {"chunk_shape" : internal_chunk_shape },
166+ }
167+ ]
164168 else :
165169 raise ValueError (f"Unsupported zarr format: { zarr_format } " )
166170
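For reference, a minimal standalone sketch of opening a dataset with a spec shaped like the one assembled above, using TensorStore's zarr3 driver. The path, shape, dtype, and chunk sizes are invented for illustration.

import numpy as np
import tensorstore as ts

data = np.arange(16, dtype=np.uint16).reshape(4, 4)
spec = {
    "driver": "zarr3",
    "kvstore": {"driver": "file", "path": "/tmp/example.zarr"},  # hypothetical path
    "metadata": {
        "shape": data.shape,
        "data_type": "uint16",
        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [2, 2]}},
        "chunk_key_encoding": {"name": "default", "configuration": {"separator": "/"}},
    },
}
# Create the dataset and write synchronously
dataset = ts.open(spec, create=True, dtype=data.dtype).result()
dataset[...] = data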
@@ -169,15 +173,70 @@ def _write_with_tensorstore(
         if create_dataset:
             dataset = ts.open(spec, create=True, dtype=array.dtype).result()
         else:
-            dataset = ts.open(spec, create=False, dtype=array.dtype).result()
+            # For existing datasets, use a minimal spec that just specifies the path
+            existing_spec = {
+                "kvstore": {
+                    "driver": "file",
+                    "path": store_path,
+                },
+                "driver": spec["driver"],
+            }
+            dataset = ts.open(existing_spec, create=False, dtype=array.dtype).result()
     except Exception as e:
         if "ALREADY_EXISTS" in str(e) and create_dataset:
             # Dataset already exists, open it without creating
-            dataset = ts.open(spec, create=False, dtype=array.dtype).result()
+            existing_spec = {
+                "kvstore": {
+                    "driver": "file",
+                    "path": store_path,
+                },
+                "driver": spec["driver"],
+            }
+            dataset = ts.open(existing_spec, create=False, dtype=array.dtype).result()
         else:
             raise
 
-    dataset[region] = array
+    # Try to write the dask array directly first
+    try:
+        dataset[region] = array
+    except Exception as e:
+        # If we encounter dimension mismatch or shape-related errors,
+        # compute the array and try again with corrective action
+        error_msg = str(e).lower()
+        if any(
+            keyword in error_msg
+            for keyword in [
+                "dimension",
+                "shape",
+                "mismatch",
+                "size",
+                "extent",
+                "rank",
+                "invalid",
+            ]
+        ):
+            # Compute the array to get the actual shape
+            computed_array = array.compute()
+
+            # Adjust region to match the actual computed array shape if needed
+            if len(region) == len(computed_array.shape):
+                adjusted_region = tuple(
+                    slice(
+                        region[i].start or 0,
+                        (region[i].start or 0) + computed_array.shape[i],
+                    )
+                    if isinstance(region[i], slice)
+                    else region[i]
+                    for i in range(len(region))
+                )
+            else:
+                adjusted_region = region
+
+            # Try writing the computed array with adjusted region
+            dataset[adjusted_region] = computed_array
+        else:
+            # Re-raise the exception if it's not related to dimension/shape issues
+            raise
 
 
 def _validate_ngff_parameters(
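The slice adjustment above can be exercised in isolation. The region and computed block shape here are invented (e.g. a trailing dask block smaller than the nominal chunk size):

region = (slice(0, 128), slice(0, 128))
computed_shape = (100, 128)
adjusted_region = tuple(
    slice(r.start or 0, (r.start or 0) + s) if isinstance(r, slice) else r
    for r, s in zip(region, computed_shape)
)
assert adjusted_region == (slice(0, 100), slice(0, 128))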
@@ -310,11 +369,21 @@ def _configure_sharding(
         internal_chunk_shape = c0
         arr = arr.rechunk(shards)
 
-        # Only include 'shards' and 'chunks' in sharding_kwargs
-        sharding_kwargs = {
-            "shards": shards,
-            "chunks": c0,
-        }
+        # Configure sharding parameters differently for v2 vs v3
+        sharding_kwargs = {}
+        if zarr_version_major >= 3:
+            # For Zarr v3, configure sharding as a codec
+            # Use chunk_shape for internal chunks and configure sharding via codecs
+            sharding_kwargs["chunk_shape"] = internal_chunk_shape
+            # Note: sharding codec will be configured separately in the codecs parameter
+            # We'll pass the shard shape through a separate key to be handled later
+            sharding_kwargs["_shard_shape"] = shards
+        else:
+            # For zarr v2, use the older API
+            sharding_kwargs = {
+                "shards": shards,
+                "chunks": internal_chunk_shape,
+            }
 
     return sharding_kwargs, internal_chunk_shape, arr
 
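To make the two layouts concrete, here is a worked illustration of what _configure_sharding returns in each branch. The sizes are invented; "_shard_shape" is the private sentinel consumed later by _write_array_direct and _handle_large_array_writing.

# Invented sizes: 64^3 internal chunks, 2 chunks per shard along each axis
c0 = (64, 64, 64)         # internal (within-shard) chunk shape
shards = (128, 128, 128)  # shard shape after arr.rechunk(shards)

# zarr-python >= 3: shard shape travels under the private "_shard_shape" key
v3_kwargs = {"chunk_shape": c0, "_shard_shape": shards}

# zarr-python < 3: pass shards and chunks directly (the "older API" branch)
v2_kwargs = {"shards": shards, "chunks": c0}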
@@ -378,8 +447,37 @@ def _write_array_direct(
     arr = _prep_for_to_zarr(store, arr)
 
     zarr_fmt = format_kwargs.get("zarr_format")
+
+    # Handle sharding kwargs for direct writing
+    cleaned_sharding_kwargs = {}
+
+    if sharding_kwargs and "_shard_shape" in sharding_kwargs:
+        # For Zarr v3 direct writes, use shards and chunks parameters
+        shard_shape = sharding_kwargs["_shard_shape"]
+        internal_chunk_shape = sharding_kwargs.get("chunk_shape")
+
+        # Ensure internal_chunk_shape is available
+        if internal_chunk_shape is None:
+            # Use chunks from arr if available, or default
+            internal_chunk_shape = tuple(arr.chunks[i][0] for i in range(arr.ndim))
+
+        # For direct Zarr v3 writes, use shards and chunks
+        cleaned_sharding_kwargs["shards"] = shard_shape
+        cleaned_sharding_kwargs["chunks"] = internal_chunk_shape
+
+        # Remove internal kwargs
+        cleaned_sharding_kwargs.update(
+            {
+                k: v
+                for k, v in sharding_kwargs.items()
+                if k not in ["_shard_shape", "chunk_shape"]
+            }
+        )
+    else:
+        cleaned_sharding_kwargs = sharding_kwargs
+
     to_zarr_kwargs = {
-        **sharding_kwargs,
+        **cleaned_sharding_kwargs,
         **zarr_kwargs,
         **format_kwargs,
         **dimension_names_kwargs,
@@ -401,7 +499,9 @@ def _write_array_direct(
         array[:] = arr.compute()
     else:
         # All other cases: use dask.array.to_zarr
-        target = zarr_array if (region is not None and zarr_array is not None) else store
+        target = (
+            zarr_array if (region is not None and zarr_array is not None) else store
+        )
         dask.array.to_zarr(
             arr,
             target,
@@ -414,7 +514,6 @@ def _write_array_direct(
         )
 
 
-
 def _handle_large_array_writing(
     image,
     arr: dask.array.Array,
@@ -448,15 +547,71 @@ def _handle_large_array_writing(
 
     chunks = tuple([c[0] for c in arr.chunks])
 
+    # If sharding is enabled, configure it properly
+    chunk_kwargs = {}
+    codecs_kwargs = {}
+
+    if sharding_kwargs:
+        if "_shard_shape" in sharding_kwargs:
+            # For Zarr v3, configure sharding as a codec only
+            shard_shape = sharding_kwargs.pop("_shard_shape")
+            internal_chunk_shape = sharding_kwargs.get(
+                "chunk_shape"
+            )  # This is the inner chunk shape
+
+            # Configure the sharding codec with proper defaults
+            from zarr.codecs.sharding import ShardingCodec
+            from zarr.codecs.bytes import BytesCodec
+            from zarr.codecs.zstd import ZstdCodec
+
+            # Default inner codecs for sharding
+            default_codecs = [BytesCodec(), ZstdCodec()]
+
+            # Ensure internal_chunk_shape is available; fall back to chunks if needed
+            if internal_chunk_shape is None:
+                internal_chunk_shape = chunks
+
+            # The array's chunk_shape should be the shard shape
+            # The sharding codec's chunk_shape should be the internal chunk shape
+            sharding_codec = ShardingCodec(
+                chunk_shape=internal_chunk_shape,  # Internal chunk shape within shards
+                codecs=default_codecs,
+            )
+
+            # Set up codecs with sharding
+            existing_codecs = zarr_kwargs.get("codecs", [])
+            if not isinstance(existing_codecs, list):
+                existing_codecs = []
+            codecs_kwargs["codecs"] = [sharding_codec] + existing_codecs
+
+            # Set the array's chunk_shape to the shard shape
+            chunk_kwargs["chunk_shape"] = shard_shape
+
+            # Clean up remaining kwargs (remove chunk_shape since we're setting it explicitly)
+            remaining_kwargs = {
+                k: v
+                for k, v in sharding_kwargs.items()
+                if k not in ["_shard_shape", "chunk_shape"]
+            }
+            sharding_kwargs_clean = remaining_kwargs
+        else:
+            # For Zarr v2 or other cases
+            sharding_kwargs_clean = sharding_kwargs
+    else:
+        # No sharding
+        chunk_kwargs = {"chunks": chunks}
+        sharding_kwargs_clean = {}
+
     zarr_array = open_array(
         shape=arr.shape,
-        chunks=chunks,
         dtype=arr.dtype,
         store=store,
         path=path,
         mode="a",
-        **sharding_kwargs,
+        **chunk_kwargs,
+        **sharding_kwargs_clean,
         **zarr_kwargs,
+        **codecs_kwargs,
         **dimension_names_kwargs,
         **format_kwargs,
     )
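For comparison with the explicit ShardingCodec wiring above, a minimal sketch of the higher-level route, assuming a zarr-python release (>= 3.0) whose create_array accepts shards; the store path and sizes are invented:

import numpy as np
import zarr

z = zarr.create_array(
    store="/tmp/sharded_example.zarr",  # hypothetical path
    shape=(1024, 1024),
    shards=(512, 512),   # shard (outer) shape stored as one object
    chunks=(128, 128),   # internal chunk shape within each shard
    dtype="uint16",
)
z[:256, :256] = np.ones((256, 256), dtype="uint16")

Passing shards/chunks this way should build the same sharding_indexed layout that the explicit codec construction above assembles by hand, while the explicit form keeps control over the inner codecs.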
@@ -491,7 +646,7 @@ def _handle_large_array_writing(
             store_path,
             path,
             optimized,
-            [c[0] for c in arr_region.chunks],
+            chunks,  # Use original array chunks, not region chunks
             shards,
             internal_chunk_shape,
             zarr_format,
@@ -830,18 +985,11 @@ def to_ngff_zarr(
                 arr, chunks_per_shard, dims, kwargs.copy()
             )
 
-            # Get the chunks and optional shards for TensorStore
+            # Get the chunks - these are now the shards if sharding is enabled
            chunks = tuple([c[0] for c in arr.chunks])
-            shards = None
-            if chunks_per_shard is not None:
-                if isinstance(chunks_per_shard, int):
-                    shards = tuple([c * chunks_per_shard for c in chunks])
-                elif isinstance(chunks_per_shard, (tuple, list)):
-                    shards = tuple([c * chunks_per_shard[i] for i, c in enumerate(chunks)])
-                elif isinstance(chunks_per_shard, dict):
-                    shards = tuple(
-                        [c * chunks_per_shard.get(dims[i], 1) for i, c in enumerate(chunks)]
-                    )
+
+            # For TensorStore, shards are the same as chunks when sharding is enabled
+            shards = chunks if chunks_per_shard is not None else None
 
             # Determine write method based on memory requirements
             if memory_usage(image) > config.memory_target and multiscales.scale_factors:
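A small illustration, with invented sizes, of why chunks and shards coincide here: _configure_sharding has already rechunked the dask array to the shard shape, so arr.chunks reports shard-sized blocks rather than the internal chunks.

import dask.array as da

arr = da.zeros((256, 256), chunks=(64, 64))  # internal chunk shape
arr = arr.rechunk((128, 128))                # rechunked to the shard shape
chunks = tuple(c[0] for c in arr.chunks)
assert chunks == (128, 128)                  # equals the shard shape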