Commit 32bd137

Merge pull request #174 from thewtex/tensorstore-chunk-consist-rebase
BUG: Consistent chunking with tensorstore across writes
2 parents: 9f12a63 + f8817ba

File tree

2 files changed: +296, -42 lines


py/ngff_zarr/to_ngff_zarr.py

Lines changed: 189 additions & 41 deletions
@@ -125,6 +125,7 @@ def _write_with_tensorstore(
     # Use full array shape if provided, otherwise use the region array shape
     dataset_shape = full_array_shape if full_array_shape is not None else array.shape

+    # Build the base spec
     spec = {
         "kvstore": {
             "driver": "file",
@@ -134,33 +135,36 @@ def _write_with_tensorstore(
             "shape": dataset_shape,
         },
     }
+
     if zarr_format == 2:
         spec["driver"] = "zarr" if zarr_version_major < 3 else "zarr2"
-        spec["metadata"]["chunks"] = chunks
         spec["metadata"]["dimension_separator"] = "/"
         spec["metadata"]["dtype"] = array.dtype.str
+        # Only add chunk info when creating the dataset
+        if create_dataset:
+            spec["metadata"]["chunks"] = chunks
     elif zarr_format == 3:
         spec["driver"] = "zarr3"
-        spec["metadata"]["chunk_grid"] = {
-            "name": "regular",
-            "configuration": {"chunk_shape": chunks},
-        }
         spec["metadata"]["data_type"] = _numpy_to_zarr_dtype(array.dtype)
-        spec['metadata']["chunk_key_encoding"] = {
+        spec["metadata"]["chunk_key_encoding"] = {
             "name": "default",
-            "configuration": {
-                "separator": "/"
-            }
+            "configuration": {"separator": "/"},
         }
         if dimension_names:
             spec["metadata"]["dimension_names"] = dimension_names
-        if internal_chunk_shape:
-            spec["metadata"]["codecs"] = [
-                {
-                    "name": "sharding_indexed",
-                    "configuration": {"chunk_shape": internal_chunk_shape},
-                }
-            ]
+        # Only add chunk info when creating the dataset
+        if create_dataset:
+            spec["metadata"]["chunk_grid"] = {
+                "name": "regular",
+                "configuration": {"chunk_shape": chunks},
+            }
+            if internal_chunk_shape:
+                spec["metadata"]["codecs"] = [
+                    {
+                        "name": "sharding_indexed",
+                        "configuration": {"chunk_shape": internal_chunk_shape},
+                    }
+                ]
     else:
         raise ValueError(f"Unsupported zarr format: {zarr_format}")

@@ -169,15 +173,70 @@ def _write_with_tensorstore(
         if create_dataset:
             dataset = ts.open(spec, create=True, dtype=array.dtype).result()
         else:
-            dataset = ts.open(spec, create=False, dtype=array.dtype).result()
+            # For existing datasets, use a minimal spec that just specifies the path
+            existing_spec = {
+                "kvstore": {
+                    "driver": "file",
+                    "path": store_path,
+                },
+                "driver": spec["driver"],
+            }
+            dataset = ts.open(existing_spec, create=False, dtype=array.dtype).result()
     except Exception as e:
         if "ALREADY_EXISTS" in str(e) and create_dataset:
             # Dataset already exists, open it without creating
-            dataset = ts.open(spec, create=False, dtype=array.dtype).result()
+            existing_spec = {
+                "kvstore": {
+                    "driver": "file",
+                    "path": store_path,
+                },
+                "driver": spec["driver"],
+            }
+            dataset = ts.open(existing_spec, create=False, dtype=array.dtype).result()
         else:
             raise

-    dataset[region] = array
+    # Try to write the dask array directly first
+    try:
+        dataset[region] = array
+    except Exception as e:
+        # If we encounter dimension mismatch or shape-related errors,
+        # compute the array and try again with corrective action
+        error_msg = str(e).lower()
+        if any(
+            keyword in error_msg
+            for keyword in [
+                "dimension",
+                "shape",
+                "mismatch",
+                "size",
+                "extent",
+                "rank",
+                "invalid",
+            ]
+        ):
+            # Compute the array to get the actual shape
+            computed_array = array.compute()
+
+            # Adjust region to match the actual computed array shape if needed
+            if len(region) == len(computed_array.shape):
+                adjusted_region = tuple(
+                    slice(
+                        region[i].start or 0,
+                        (region[i].start or 0) + computed_array.shape[i],
+                    )
+                    if isinstance(region[i], slice)
+                    else region[i]
+                    for i in range(len(region))
+                )
+            else:
+                adjusted_region = region
+
+            # Try writing the computed array with adjusted region
+            dataset[adjusted_region] = computed_array
+        else:
+            # Re-raise the exception if it's not related to dimension/shape issues
+            raise


 def _validate_ngff_parameters(
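
For context on the reopen path above: a minimal sketch of the create-then-reopen pattern with TensorStore, assuming tensorstore is installed; the path, shape, dtype, and chunk sizes are illustrative and not taken from this commit. Reopening with only the kvstore path and driver lets the chunk layout be read back from the stored metadata before a regional write.

# Sketch only: illustrative path/shape/dtype, assuming `tensorstore` is installed.
import numpy as np
import tensorstore as ts

store_path = "/tmp/example.zarr"

# Create once, with explicit chunk metadata (mirrors the create_dataset branch).
create_spec = {
    "driver": "zarr3",
    "kvstore": {"driver": "file", "path": store_path},
    "metadata": {
        "shape": [256, 256],
        "data_type": "uint16",
        "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": [64, 64]}},
    },
}
ts.open(create_spec, create=True, dtype=np.uint16).result()

# Reopen with a minimal spec: the chunk layout comes from the stored metadata,
# so a later regional write cannot conflict with the creation-time chunks.
reopen_spec = {
    "driver": "zarr3",
    "kvstore": {"driver": "file", "path": store_path},
}
dataset = ts.open(reopen_spec, create=False, dtype=np.uint16).result()
dataset[0:64, 0:64] = np.ones((64, 64), dtype=np.uint16)
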
@@ -310,11 +369,21 @@ def _configure_sharding(
         internal_chunk_shape = c0
         arr = arr.rechunk(shards)

-        # Only include 'shards' and 'chunks' in sharding_kwargs
-        sharding_kwargs = {
-            "shards": shards,
-            "chunks": c0,
-        }
+        # Configure sharding parameters differently for v2 vs v3
+        sharding_kwargs = {}
+        if zarr_version_major >= 3:
+            # For Zarr v3, configure sharding as a codec
+            # Use chunk_shape for internal chunks and configure sharding via codecs
+            sharding_kwargs["chunk_shape"] = internal_chunk_shape
+            # Note: sharding codec will be configured separately in the codecs parameter
+            # We'll pass the shard shape through a separate key to be handled later
+            sharding_kwargs["_shard_shape"] = shards
+        else:
+            # For zarr v2, use the older API
+            sharding_kwargs = {
+                "shards": shards,
+                "chunks": internal_chunk_shape,
+            }

     return sharding_kwargs, internal_chunk_shape, arr

@@ -378,8 +447,37 @@ def _write_array_direct(
         arr = _prep_for_to_zarr(store, arr)

     zarr_fmt = format_kwargs.get("zarr_format")
+
+    # Handle sharding kwargs for direct writing
+    cleaned_sharding_kwargs = {}
+
+    if sharding_kwargs and "_shard_shape" in sharding_kwargs:
+        # For Zarr v3 direct writes, use shards and chunks parameters
+        shard_shape = sharding_kwargs["_shard_shape"]
+        internal_chunk_shape = sharding_kwargs.get("chunk_shape")
+
+        # Ensure internal_chunk_shape is available
+        if internal_chunk_shape is None:
+            # Use chunks from arr if available, or default
+            internal_chunk_shape = tuple(arr.chunks[i][0] for i in range(arr.ndim))
+
+        # For direct Zarr v3 writes, use shards and chunks
+        cleaned_sharding_kwargs["shards"] = shard_shape
+        cleaned_sharding_kwargs["chunks"] = internal_chunk_shape
+
+        # Remove internal kwargs
+        cleaned_sharding_kwargs.update(
+            {
+                k: v
+                for k, v in sharding_kwargs.items()
+                if k not in ["_shard_shape", "chunk_shape"]
+            }
+        )
+    else:
+        cleaned_sharding_kwargs = sharding_kwargs
+
     to_zarr_kwargs = {
-        **sharding_kwargs,
+        **cleaned_sharding_kwargs,
         **zarr_kwargs,
         **format_kwargs,
         **dimension_names_kwargs,
@@ -401,7 +499,9 @@ def _write_array_direct(
         array[:] = arr.compute()
     else:
         # All other cases: use dask.array.to_zarr
-        target = zarr_array if (region is not None and zarr_array is not None) else store
+        target = (
+            zarr_array if (region is not None and zarr_array is not None) else store
+        )
         dask.array.to_zarr(
             arr,
             target,
@@ -414,7 +514,6 @@ def _write_array_direct(
         )


-
 def _handle_large_array_writing(
     image,
     arr: dask.array.Array,
@@ -448,15 +547,71 @@ def _handle_large_array_writing(

     chunks = tuple([c[0] for c in arr.chunks])

+    # If sharding is enabled, configure it properly
+    chunk_kwargs = {}
+    codecs_kwargs = {}
+
+    if sharding_kwargs:
+        if "_shard_shape" in sharding_kwargs:
+            # For Zarr v3, configure sharding as a codec only
+            shard_shape = sharding_kwargs.pop("_shard_shape")
+            internal_chunk_shape = sharding_kwargs.get(
+                "chunk_shape"
+            )  # This is the inner chunk shape
+
+            # Configure the sharding codec with proper defaults
+            from zarr.codecs.sharding import ShardingCodec
+            from zarr.codecs.bytes import BytesCodec
+            from zarr.codecs.zstd import ZstdCodec
+
+            # Default inner codecs for sharding
+            default_codecs = [BytesCodec(), ZstdCodec()]
+
+            # Ensure internal_chunk_shape is available; fallback to chunks if needed
+            if internal_chunk_shape is None:
+                internal_chunk_shape = chunks
+
+            # The array's chunk_shape should be the shard shape
+            # The sharding codec's chunk_shape should be the internal chunk shape
+            sharding_codec = ShardingCodec(
+                chunk_shape=internal_chunk_shape,  # Internal chunk shape within shards
+                codecs=default_codecs,
+            )
+
+            # Set up codecs with sharding
+            existing_codecs = zarr_kwargs.get("codecs", [])
+            if not isinstance(existing_codecs, list):
+                existing_codecs = []
+            codecs_kwargs["codecs"] = [sharding_codec] + existing_codecs
+
+            # Set the array's chunk_shape to the shard shape
+            chunk_kwargs["chunk_shape"] = shard_shape
+
+            # Clean up remaining kwargs (remove chunk_shape since we're setting it explicitly)
+            remaining_kwargs = {
+                k: v
+                for k, v in sharding_kwargs.items()
+                if k not in ["_shard_shape", "chunk_shape"]
+            }
+            sharding_kwargs_clean = remaining_kwargs
+        else:
+            # For Zarr v2 or other cases
+            sharding_kwargs_clean = sharding_kwargs
+    else:
+        # No sharding
+        chunk_kwargs = {"chunks": chunks}
+        sharding_kwargs_clean = {}
+
     zarr_array = open_array(
         shape=arr.shape,
-        chunks=chunks,
         dtype=arr.dtype,
         store=store,
         path=path,
         mode="a",
-        **sharding_kwargs,
+        **chunk_kwargs,
+        **sharding_kwargs_clean,
         **zarr_kwargs,
+        **codecs_kwargs,
         **dimension_names_kwargs,
         **format_kwargs,
     )
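
The block above maps the shard shape to the array-level chunk grid and the inner chunk shape to the ShardingCodec. A minimal sketch of that relationship with plain zarr-python (assumes zarr-python >= 3; the store path and shapes are illustrative, not from this commit):

# Sketch only: assumes zarr-python >= 3; shapes and path are illustrative.
import zarr
from zarr.codecs.bytes import BytesCodec
from zarr.codecs.sharding import ShardingCodec
from zarr.codecs.zstd import ZstdCodec

# Inner 64x64 chunks are packed into 256x256 shards: the array-level chunk grid
# is the shard shape, the codec-level chunk_shape is the inner chunk shape.
sharding = ShardingCodec(chunk_shape=(64, 64), codecs=[BytesCodec(), ZstdCodec()])
arr = zarr.create(
    shape=(1024, 1024),
    chunks=(256, 256),  # shard shape (array-level chunk grid)
    dtype="uint16",
    codecs=[sharding],
    store="/tmp/sharded.zarr",
    zarr_format=3,
)
arr[0:256, 0:256] = 1
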
@@ -491,7 +646,7 @@ def _handle_large_array_writing(
             store_path,
             path,
             optimized,
-            [c[0] for c in arr_region.chunks],
+            chunks,  # Use original array chunks, not region chunks
             shards,
             internal_chunk_shape,
             zarr_format,
@@ -830,18 +985,11 @@ def to_ngff_zarr(
             arr, chunks_per_shard, dims, kwargs.copy()
         )

-        # Get the chunks and optional shards for TensorStore
+        # Get the chunks - these are now the shards if sharding is enabled
         chunks = tuple([c[0] for c in arr.chunks])
-        shards = None
-        if chunks_per_shard is not None:
-            if isinstance(chunks_per_shard, int):
-                shards = tuple([c * chunks_per_shard for c in chunks])
-            elif isinstance(chunks_per_shard, (tuple, list)):
-                shards = tuple([c * chunks_per_shard[i] for i, c in enumerate(chunks)])
-            elif isinstance(chunks_per_shard, dict):
-                shards = tuple(
-                    [c * chunks_per_shard.get(dims[i], 1) for i, c in enumerate(chunks)]
-                )
+
+        # For TensorStore, shards are the same as chunks when sharding is enabled
+        shards = chunks if chunks_per_shard is not None else None

         # Determine write method based on memory requirements
         if memory_usage(image) > config.memory_target and multiscales.scale_factors:
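
At the user level, the effect is that the chunks recorded in the store are the shard shape whenever chunks_per_shard is given, so subsequent regional writes, including the TensorStore path taken when an image exceeds the memory target, see consistent chunk metadata. A rough usage sketch: chunks_per_shard is the kwarg touched by this commit, while the other names (to_ngff_image, to_multiscales, version) are assumed from the ngff-zarr public API and the shapes are illustrative.

# Sketch only: shapes are illustrative; `chunks_per_shard` is the kwarg touched
# by this commit, the other names are assumed from the ngff-zarr public API.
import numpy as np
from ngff_zarr import to_multiscales, to_ngff_image, to_ngff_zarr

image = to_ngff_image(
    np.zeros((64, 512, 512), dtype=np.uint16), dims=("z", "y", "x")
)
multiscales = to_multiscales(image, scale_factors=[2, 4])

# With sharding requested, the chunks recorded in the store are the shard shape,
# and writes across scales reuse that same chunk metadata.
to_ngff_zarr("image.ome.zarr", multiscales, chunks_per_shard=2, version="0.5")
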

0 commit comments
