Skip to content

Commit 757c4be

Browse files
committed
Merge branch 'main' into pre-commit-ci-update-config
2 parents 62f6a4a + f4a3d84 commit 757c4be

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

src/pyramid_sampler/sampler.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,16 @@ def _downsample_by_one_level(
142142
zarr_field: str,
143143
) -> None:
144144
level = coarse_level
145+
level_str = str(level)
145146
fine_level = level - 1
146-
lev_shape = self._get_level_shape(level)
147+
fine_lev_str = str(fine_level)
148+
lev_shape = self._get_level_shape(level).tolist()
147149

148150
field1 = zarr.open(self.zarr_store_path)[zarr_field]
149-
dtype = field1[fine_level].dtype
150-
field1.empty(level, shape=lev_shape, chunks=self.chunks, dtype=dtype)
151+
dtype = field1[fine_lev_str].dtype
152+
field1.empty(name=level_str, shape=lev_shape, chunks=self.chunks, dtype=dtype)
151153

152-
numchunks = field1[str(level)].nchunks
154+
numchunks = field1[level_str].nchunks
153155

154156
chunk_writes = []
155157
for ichunk in range(numchunks):
@@ -270,7 +272,7 @@ def initialize_test_image(
270272
"""
271273
if dtype is None:
272274
dtype = np.float64
273-
field1 = zarr_store.create_group(zarr_field, overwrite=overwrite_field)
275+
field1 = zarr_store.create_group(name=zarr_field, overwrite=overwrite_field)
274276

275277
if chunks is None:
276278
chunks = (64, 64, 64)
@@ -288,5 +290,5 @@ def initialize_test_image(
288290
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] = (
289291
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5 * fac
290292
)
291-
field1.empty(0, shape=base_resolution, chunks=chunks, dtype=dtype)
293+
field1.empty(name="0", shape=base_resolution, chunks=chunks, dtype=dtype)
292294
da.to_zarr(lev0, field1["0"])

tests/test_sampler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ def test_initialize_test_image(tmp_path):
1818
initialize_test_image(zarr_store, fieldname, res, chunks, overwrite_field=False)
1919

2020
assert fieldname in zarr_store
21-
assert zarr_store[fieldname][0].shape == res
22-
assert zarr_store[fieldname][0].chunks == chunks
21+
assert zarr_store[fieldname]["0"].shape == res
22+
assert zarr_store[fieldname]["0"].chunks == chunks
2323
assert Path.exists(tmp_path / "myzarr.zarr" / fieldname)
2424

2525
res = (16, 16, 16)
2626
initialize_test_image(zarr_store, fieldname, res, chunks, overwrite_field=True)
27-
assert zarr_store[fieldname][0].shape == res
27+
assert zarr_store[fieldname]["0"].shape == res
2828

2929

3030
@pytest.mark.parametrize("dtype", ["float32", np.float64, "int", np.int32, np.int16])
@@ -43,8 +43,8 @@ def test_downsampler(tmp_path, dtype):
4343
dsr.downsample(10, fieldname)
4444
expected_max_lev = 2
4545
for lev in range(expected_max_lev + 1):
46-
assert lev in zarr_store[fieldname]
47-
assert zarr_store[fieldname][lev].dtype == np.dtype(dtype)
46+
assert str(lev) in zarr_store[fieldname]
47+
assert zarr_store[fieldname][str(lev)].dtype == np.dtype(dtype)
4848

4949
with pytest.raises(ValueError, match="max_level must exceed 0"):
5050
dsr.downsample(0, fieldname)

0 commit comments

Comments
 (0)