4
4
import pytest
5
5
import xarray as xr
6
6
from cubed .runtime .create import create_executor
7
+ from numpy .testing import assert_array_equal
7
8
from xarray .namedarray .parallelcompat import list_chunkmanagers
8
9
from xarray .tests import assert_allclose , create_test_data
9
10
@@ -24,6 +25,17 @@ def executor(request):
24
25
return request .param
25
26
26
27
28
+ def assert_identical (a , b ):
29
+ """A version of this function which accepts numpy arrays"""
30
+ __tracebackhide__ = True
31
+ from xarray .testing import assert_identical as assert_identical_
32
+
33
+ if hasattr (a , "identical" ):
34
+ assert_identical_ (a , b )
35
+ else :
36
+ assert_array_equal (a , b )
37
+
38
+
27
39
class TestDiscoverCubedManager :
28
40
def test_list_cubedmanager (self ):
29
41
chunkmanagers = list_chunkmanagers ()
@@ -72,3 +84,31 @@ def test_dataset_accessor_visualize(tmp_path):
72
84
assert not (tmp_path / "cubed.svg" ).exists ()
73
85
ds .cubed .visualize (filename = tmp_path / "cubed" )
74
86
assert (tmp_path / "cubed.svg" ).exists ()
87
+
88
+
89
+ def identity (x ):
90
+ return x
91
+
92
+
93
+ # based on test_apply_dask_parallelized_one_arg
94
+ def test_apply_ufunc_parallelized_one_arg ():
95
+ array = cubed .ones ((2 , 2 ), chunks = (1 , 1 ))
96
+ data_array = xr .DataArray (array , dims = ("x" , "y" ))
97
+
98
+ def parallel_identity (x ):
99
+ return xr .apply_ufunc (
100
+ identity ,
101
+ x ,
102
+ output_dtypes = [x .dtype ],
103
+ dask = "parallelized" ,
104
+ dask_gufunc_kwargs = {"allow_rechunk" : False },
105
+ )
106
+
107
+ actual = parallel_identity (data_array )
108
+ assert isinstance (actual .data , cubed .Array )
109
+ assert actual .data .chunks == array .chunks
110
+ assert_identical (data_array , actual )
111
+
112
+ computed = data_array .compute ()
113
+ actual = parallel_identity (computed )
114
+ assert_identical (computed , actual )
0 commit comments