Skip to content

Commit 4f28e57

Browse files
committed
Ensure args are Cubed arrays in unify_chunks
1 parent d1c773b commit 4f28e57

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

cubed_xarray/cubedmanager.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import numpy as np
1111

12+
from tlz import partition
13+
1214
from xarray.core import utils
1315
from xarray.core.parallelcompat import ChunkManagerEntrypoint
1416
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
@@ -192,9 +194,17 @@ def unify_chunks(
192194
*args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
193195
**kwargs,
194196
) -> tuple[dict[str, T_NormalizedChunks], list["CubedArray"]]:
197+
from cubed.array_api import asarray
195198
from cubed.core import unify_chunks
196199

197-
return unify_chunks(*args, **kwargs)
200+
# Ensure that args are Cubed arrays. Note that we do this here and not in Cubed, following
201+
# https://numpy.org/neps/nep-0047-array-api-standard.html#the-asarray-asanyarray-pattern
202+
arginds = [
203+
(asarray(a) if ind is not None else a, ind) for a, ind in partition(2, args)
204+
]
205+
array_args = [item for pair in arginds for item in pair]
206+
207+
return unify_chunks(*array_args, **kwargs)
198208

199209
def store(
200210
self,

0 commit comments

Comments
 (0)