Skip to content

Commit 60ad9fd

Browse files
authored
Merge pull request #4 from data-exp-lab/dtype_handling
handle different dtypes
2 parents c37c6ac + 6bcded3 commit 60ad9fd

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ ignore = [
131131
"PLR09", # Too many <...>
132132
"PLR2004", # Magic value used in comparison
133133
"ISC001", # Conflicts with formatter
134+
"SIM108", # Use ternary operator
134135
]
135136
isort.required-imports = ["from __future__ import annotations"]
136137
# Uncomment if using a _compat.typing backport

src/pyramid_sampler/sampler.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,10 @@ def _downsample_by_one_level(
132132
level = coarse_level
133133
fine_level = level - 1
134134
lev_shape = self._get_level_shape(level)
135+
135136
field1 = zarr.open(self.zarr_store_path)[zarr_field]
136-
field1.empty(level, shape=lev_shape, chunks=self.chunks)
137+
dtype = field1[fine_level].dtype
138+
field1.empty(level, shape=lev_shape, chunks=self.chunks, dtype=dtype)
137139

138140
numchunks = field1[str(level)].nchunks
139141

@@ -197,7 +199,8 @@ def _write_chunk_values(
197199
)
198200

199201
coarse_zarr = zarr.open(zarr_file)[zarr_field][str(level)]
200-
coarse_zarr[si[0] : ei[0], si[1] : ei[1] :, si[2] : ei[2]] = outvals
202+
dtype = coarse_zarr.dtype
203+
coarse_zarr[si[0] : ei[0], si[1] : ei[1] :, si[2] : ei[2]] = outvals.astype(dtype)
201204

202205
return 1
203206

@@ -208,15 +211,27 @@ def initialize_test_image(
208211
base_resolution: tuple[int, int, int],
209212
chunks: int | tuple[int, int, int] | None = None,
210213
overwrite_field: bool = True,
214+
dtype: str | type | None = None,
211215
) -> None:
216+
if dtype is None:
217+
dtype = np.float64
212218
field1 = zarr_store.create_group(zarr_field, overwrite=overwrite_field)
213219

214220
if chunks is None:
215221
chunks = (64, 64, 64)
216-
lev0 = da.random.random(base_resolution, chunks=chunks)
222+
fac: int | float
223+
if np.issubdtype(dtype, np.integer):
224+
fac = 100
225+
elif np.issubdtype(dtype, np.floating):
226+
fac = 1.0
227+
else:
228+
msg = f"Unexpected dtype of {dtype}"
229+
raise RuntimeError(msg)
230+
lev0 = fac * da.random.random(base_resolution, chunks=chunks)
231+
lev0 = lev0.astype(dtype)
217232
halfway = np.asarray(base_resolution) // 2
218233
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] = (
219-
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5
234+
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5 * fac
220235
)
221-
field1.empty(0, shape=base_resolution, chunks=chunks)
236+
field1.empty(0, shape=base_resolution, chunks=chunks, dtype=dtype)
222237
da.to_zarr(lev0, field1["0"])

tests/test_sampler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,24 @@ def test_initialize_test_image(tmp_path):
2727
assert zarr_store[fieldname][0].shape == res
2828

2929

30-
def test_downsampler(tmp_path):
30+
@pytest.mark.parametrize("dtype", ["float32", np.float64, "int", np.int32, np.int16])
31+
def test_downsampler(tmp_path, dtype):
3132
tmp_zrr = str(tmp_path / "myzarr.zarr")
3233
zarr_store = zarr.open(tmp_zrr)
3334
res = (32, 32, 32)
3435
chunks = (8, 8, 8)
3536
fieldname = "test_field"
36-
initialize_test_image(zarr_store, fieldname, res, chunks, overwrite_field=False)
37+
initialize_test_image(
38+
zarr_store, fieldname, res, chunks, overwrite_field=False, dtype=dtype
39+
)
3740

3841
dsr = Downsampler(tmp_zrr, (2, 2, 2), res, chunks)
3942

4043
dsr.downsample(10, fieldname)
4144
expected_max_lev = 2
4245
for lev in range(expected_max_lev + 1):
4346
assert lev in zarr_store[fieldname]
47+
assert zarr_store[fieldname][lev].dtype == np.dtype(dtype)
4448

4549
with pytest.raises(ValueError, match="max_level must exceed 0"):
4650
dsr.downsample(0, fieldname)

0 commit comments

Comments
 (0)