Skip to content

Commit ecfb10f

Browse files
authored
Coerce args to map_blocks to arrays (#566)
1 parent c5a5d7c commit ecfb10f

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

cubed/core/ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,13 @@ def map_blocks(
578578
) -> "Array":
579579
"""Apply a function to corresponding blocks from multiple input arrays."""
580580

581+
from cubed.array_api.creation_functions import asarray
582+
583+
# Coerce all args to Cubed arrays
584+
specs = [a.spec for a in args if hasattr(a, "spec")]
585+
spec0 = specs[0] if len(specs) > 0 else spec
586+
args = tuple(asarray(a, spec=spec0) for a in args)
587+
581588
# Handle the case where an array is created by calling `map_blocks` with no input arrays
582589
if len(args) == 0:
583590
from cubed.array_api.creation_functions import empty_virtual_array

cubed/tests/test_core.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ def func(x, y):
235235
assert_array_equal(c.compute(), np.array([[[12, 13]]]))
236236

237237

238+
def test_map_blocks_with_non_cubed_array(spec):
239+
a = xp.arange(10, dtype="int64", chunks=(2,), spec=spec)
240+
b = np.array([1, 2], dtype="int64") # numpy array will be coerced to cubed
241+
c = cubed.map_blocks(nxp.add, a, b, dtype="int64")
242+
assert_array_equal(c.compute(), np.array([1, 3, 3, 5, 5, 7, 7, 9, 9, 11]))
243+
244+
238245
def test_multiple_ops(spec, executor):
239246
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
240247
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)

0 commit comments

Comments
 (0)