Skip to content

Commit 732a722

Browse files
committed
WIP: BUG: tensorstore chunk consistency
Re: #161
1 parent 2688b9c commit 732a722

File tree

2 files changed

+126
-3
lines changed

2 files changed

+126
-3
lines changed

ngff_zarr/to_ngff_zarr.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,28 @@ def _write_with_tensorstore(
118118
internal_chunk_shape=None,
119119
full_array_shape=None,
120120
create_dataset=True,
121+
original_chunks=None, # Add parameter for consistent chunk tracking
121122
) -> None:
122123
"""Write array using tensorstore backend"""
123124
import tensorstore as ts
124125

125126
# Use full array shape if provided, otherwise use the region array shape
126127
dataset_shape = full_array_shape if full_array_shape is not None else array.shape
127128

129+
# Use original chunks for consistency if provided, otherwise use current chunks
130+
consistent_chunks = original_chunks if original_chunks is not None else chunks
131+
132+
# Validate chunk shapes to prevent zero or negative values
133+
validated_chunks = []
134+
for i, chunk_size in enumerate(consistent_chunks):
135+
if chunk_size <= 0:
136+
# Fallback to minimum valid chunk size or array dimension
137+
fallback_size = min(64, dataset_shape[i]) if dataset_shape else 64
138+
validated_chunks.append(fallback_size)
139+
else:
140+
validated_chunks.append(chunk_size)
141+
consistent_chunks = tuple(validated_chunks)
142+
128143
spec = {
129144
"kvstore": {
130145
"driver": "file",
@@ -136,14 +151,14 @@ def _write_with_tensorstore(
136151
}
137152
if zarr_format == 2:
138153
spec["driver"] = "zarr"
139-
spec["metadata"]["chunks"] = chunks
154+
spec["metadata"]["chunks"] = consistent_chunks
140155
spec["metadata"]["dimension_separator"] = "/"
141156
spec["metadata"]["dtype"] = array.dtype.str
142157
elif zarr_format == 3:
143158
spec["driver"] = "zarr3"
144159
spec["metadata"]["chunk_grid"] = {
145160
"name": "regular",
146-
"configuration": {"chunk_shape": chunks},
161+
"configuration": {"chunk_shape": consistent_chunks},
147162
}
148163
spec["metadata"]["data_type"] = _numpy_to_zarr_dtype(array.dtype)
149164
if dimension_names:
@@ -337,6 +352,7 @@ def _write_array_with_tensorstore(
337352
region: Tuple[slice, ...],
338353
full_array_shape: Optional[Tuple[int, ...]] = None,
339354
create_dataset: bool = True,
355+
original_chunks: Optional[Tuple[int, ...]] = None, # Add parameter
340356
**kwargs,
341357
) -> None:
342358
"""Write an array using the TensorStore backend."""
@@ -351,6 +367,7 @@ def _write_array_with_tensorstore(
351367
dimension_names=dimension_names,
352368
full_array_shape=full_array_shape,
353369
create_dataset=create_dataset,
370+
original_chunks=original_chunks,
354371
**kwargs,
355372
)
356373
else: # Sharding
@@ -364,6 +381,7 @@ def _write_array_with_tensorstore(
364381
internal_chunk_shape=internal_chunk_shape,
365382
full_array_shape=full_array_shape,
366383
create_dataset=create_dataset,
384+
original_chunks=original_chunks,
367385
**kwargs,
368386
)
369387

@@ -435,6 +453,7 @@ def _handle_large_array_writing(
435453
progress: Optional[Union[NgffProgress, NgffProgressCallback]],
436454
index: int,
437455
nscales: int,
456+
original_chunks: Tuple[int, ...], # Add parameter to track original chunks
438457
**kwargs,
439458
) -> None:
440459
"""Handle writing large arrays by splitting them into manageable pieces."""
@@ -498,6 +517,7 @@ def _handle_large_array_writing(
498517
region,
499518
full_array_shape=arr.shape,
500519
create_dataset=(region_index == 0), # Only create on first region
520+
original_chunks=original_chunks, # Pass original chunks for consistency
501521
**kwargs,
502522
)
503523
else:
@@ -818,6 +838,9 @@ def to_ngff_zarr(
818838
dim_factors = {d: 1 for d in dims}
819839
previous_dim_factors = dim_factors
820840

841+
# Capture original chunks before any rechunking operations
842+
original_chunks = tuple([c[0] for c in arr.chunks])
843+
821844
# Configure sharding if needed
822845
sharding_kwargs, internal_chunk_shape, arr = _configure_sharding(
823846
arr, chunks_per_shard, dims, kwargs.copy()
@@ -859,6 +882,7 @@ def to_ngff_zarr(
859882
progress,
860883
index,
861884
nscales,
885+
original_chunks, # Pass original chunks for consistency
862886
**kwargs,
863887
)
864888
else:
@@ -882,6 +906,7 @@ def to_ngff_zarr(
882906
region,
883907
full_array_shape=arr.shape,
884908
create_dataset=True, # Always create for small arrays
909+
original_chunks=original_chunks, # Pass original chunks for consistency
885910
**kwargs,
886911
)
887912
else:

test/test_to_ngff_zarr_tensorstore.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,102 @@ def test_tensorstore_already_exists_failure():
142142
end_time_write = time.time()
143143
logger.info(
144144
f" Zarr written in {end_time_write - start_time_write:.2f} seconds (UNEXPECTED SUCCESS)."
145-
)
145+
)
146+
147+
148+
def test_tensorstore_chunk_shape_consistency():
149+
"""Test that TensorStore handles chunk shape consistency for edge case dimensions."""
150+
# This reproduces the issue from #161 where array dimensions don't divide evenly by chunk size
151+
# Shape needs to be large enough to trigger regional writing and have uneven chunk divisions
152+
pytest.importorskip("tensorstore")
153+
154+
shape = (
155+
515,
156+
512,
157+
512,
158+
) # Large enough to trigger issue, with uneven division in first dimension
159+
test_array = np.random.rand(*shape).astype(np.float32)
160+
161+
image = to_ngff_image(
162+
test_array,
163+
dims=("z", "y", "x"),
164+
scale={"z": 1.0, "y": 1.0, "x": 1.0},
165+
)
166+
167+
multiscales = to_multiscales(
168+
image,
169+
method=Methods.ITKWASM_GAUSSIAN,
170+
cache=False,
171+
)
172+
173+
with tempfile.TemporaryDirectory() as tmpdir:
174+
# This should not fail with chunk shape mismatch errors
175+
to_ngff_zarr(
176+
store=tmpdir,
177+
multiscales=multiscales,
178+
use_tensorstore=True,
179+
version="0.5",
180+
)
181+
182+
183+
def test_tensorstore_chunk_shape_consistency_with_sharding():
184+
"""Test TensorStore with sharding and edge case dimensions."""
185+
pytest.importorskip("tensorstore")
186+
187+
shape = (
188+
515,
189+
512,
190+
512,
191+
) # Large enough to trigger issue, with uneven division in first dimension
192+
test_array = np.random.rand(*shape).astype(np.float32)
193+
194+
image = to_ngff_image(
195+
test_array,
196+
dims=("z", "y", "x"),
197+
scale={"z": 1.0, "y": 1.0, "x": 1.0},
198+
)
199+
200+
multiscales = to_multiscales(
201+
image,
202+
method=Methods.ITKWASM_GAUSSIAN,
203+
cache=False,
204+
)
205+
206+
with tempfile.TemporaryDirectory() as tmpdir:
207+
# This should not fail with chunk shape mismatch errors
208+
to_ngff_zarr(
209+
store=tmpdir,
210+
multiscales=multiscales,
211+
use_tensorstore=True,
212+
chunks_per_shard=2,
213+
version="0.5",
214+
)
215+
216+
217+
def test_tensorstore_zero_chunk_validation():
218+
"""Test that zero chunk sizes are properly validated."""
219+
pytest.importorskip("tensorstore")
220+
221+
shape = (513, 512, 512) # Large enough to trigger regional writing with edge chunks
222+
test_array = np.random.rand(*shape).astype(np.float32)
223+
224+
image = to_ngff_image(
225+
test_array,
226+
dims=("z", "y", "x"),
227+
scale={"z": 1.0, "y": 1.0, "x": 1.0},
228+
)
229+
230+
multiscales = to_multiscales(
231+
image,
232+
method=Methods.ITKWASM_GAUSSIAN,
233+
cache=False,
234+
)
235+
236+
with tempfile.TemporaryDirectory() as tmpdir:
237+
# This should not fail with zero chunk size errors
238+
to_ngff_zarr(
239+
store=tmpdir,
240+
multiscales=multiscales,
241+
use_tensorstore=True,
242+
version="0.5",
243+
)

0 commit comments

Comments (0)