Skip to content

Commit a21a5fa

Browse files
authored
Incorrect dtypes in map_selection (#669)
1 parent 6dc00e1 commit a21a5fa

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

cubed/core/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,9 +816,9 @@ def key_function(out_key):
816816
key_function,
817817
x,
818818
shapes=[shape],
819-
dtypes=[x.dtype],
819+
dtypes=[dtype],
820820
chunkss=[chunks],
821-
extra_func_kwargs=dict(func=func, dtype=dtype),
821+
extra_func_kwargs=dict(func=func, dtype=x.dtype),
822822
num_input_blocks=num_input_blocks,
823823
iterable_input_blocks=iterable_input_blocks,
824824
selection_function=selection_function,

cubed/tests/test_overlap.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,24 @@ def test_map_overlap_1d_single_chunk():
3939
assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 4, 5, 0]))
4040

4141

42+
def test_map_overlap_1d_change_dtype():
43+
x = np.arange(6)
44+
a = xp.asarray(x, chunks=(3,))
45+
46+
b = cubed.map_overlap(
47+
lambda x: x.astype(np.float64),
48+
a,
49+
dtype=np.float64,
50+
chunks=((5, 5),),
51+
depth=1,
52+
boundary=0,
53+
trim=False,
54+
)
55+
56+
assert b.dtype == np.float64
57+
assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0]))
58+
59+
4260
def test_map_overlap_2d():
4361
x = np.arange(36).reshape((6, 6))
4462
a = xp.asarray(x, chunks=(3, 3))

0 commit comments

Comments
 (0)