Skip to content

Commit df937c8

Browse files
committed
Add test for apply_ufunc
1 parent 4a66612 commit df937c8

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

cubed_xarray/tests/test_wrapping.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import xarray as xr
66
from cubed.runtime.create import create_executor
7+
from numpy.testing import assert_array_equal
78
from xarray.namedarray.parallelcompat import list_chunkmanagers
89
from xarray.tests import assert_allclose, create_test_data
910

@@ -24,6 +25,17 @@ def executor(request):
2425
return request.param
2526

2627

28+
def assert_identical(a, b):
29+
"""A version of this function which accepts numpy arrays"""
30+
__tracebackhide__ = True
31+
from xarray.testing import assert_identical as assert_identical_
32+
33+
if hasattr(a, "identical"):
34+
assert_identical_(a, b)
35+
else:
36+
assert_array_equal(a, b)
37+
38+
2739
class TestDiscoverCubedManager:
2840
def test_list_cubedmanager(self):
2941
chunkmanagers = list_chunkmanagers()
@@ -72,3 +84,31 @@ def test_dataset_accessor_visualize(tmp_path):
7284
assert not (tmp_path / "cubed.svg").exists()
7385
ds.cubed.visualize(filename=tmp_path / "cubed")
7486
assert (tmp_path / "cubed.svg").exists()
87+
88+
89+
def identity(x):
90+
return x
91+
92+
93+
# based on test_apply_dask_parallelized_one_arg
94+
def test_apply_ufunc_parallelized_one_arg():
95+
array = cubed.ones((2, 2), chunks=(1, 1))
96+
data_array = xr.DataArray(array, dims=("x", "y"))
97+
98+
def parallel_identity(x):
99+
return xr.apply_ufunc(
100+
identity,
101+
x,
102+
output_dtypes=[x.dtype],
103+
dask="parallelized",
104+
dask_gufunc_kwargs={"allow_rechunk": False},
105+
)
106+
107+
actual = parallel_identity(data_array)
108+
assert isinstance(actual.data, cubed.Array)
109+
assert actual.data.chunks == array.chunks
110+
assert_identical(data_array, actual)
111+
112+
computed = data_array.compute()
113+
actual = parallel_identity(computed)
114+
assert_identical(computed, actual)

0 commit comments

Comments
 (0)