1
+ from __future__ import annotations
2
+
1
3
import functools
2
4
import sys
3
5
from abc import ABC , abstractmethod
10
12
from xarray .core import utils
11
13
from xarray .core .parallelcompat import ChunkManagerEntrypoint
12
14
from xarray .core .pycompat import is_chunked_array , is_duck_dask_array
13
- from xarray .core .types import T_Chunks , T_NormalizedChunks
14
-
15
- T_ChunkedArray = TypeVar ("T_ChunkedArray" )
16
15
17
- CHUNK_MANAGERS : dict [str , type ["ChunkManagerEntrypoint" ]] = {}
18
16
19
17
if TYPE_CHECKING :
20
- from xarray .core .types import CubedArray , ZarrArray
18
+ from xarray .core .types import T_Chunks , T_NormalizedChunks
19
+ from cubed import Array as CubedArray
21
20
22
21
23
22
class CubedManager (ChunkManagerEntrypoint ["CubedArray" ]):
24
23
array_cls : type ["CubedArray" ]
25
24
26
- def __init__ (self ):
25
+ def __init__ (self ) -> None :
27
26
from cubed import Array
28
27
29
28
self .array_cls = Array
@@ -33,15 +32,21 @@ def chunks(self, data: "CubedArray") -> T_NormalizedChunks:
33
32
34
33
def normalize_chunks (
35
34
self ,
36
- chunks : T_Chunks ,
37
- shape : Union [ tuple [int ], None ] = None ,
38
- limit : Union [ int , None ] = None ,
39
- dtype : Union [ np .dtype , None ] = None ,
40
- previous_chunks : T_NormalizedChunks = None ,
41
- ) -> tuple [ tuple [ int , ...], ...] :
35
+ chunks : T_Chunks | T_NormalizedChunks ,
36
+ shape : tuple [int , ...] | None = None ,
37
+ limit : int | None = None ,
38
+ dtype : np .dtype | None = None ,
39
+ previous_chunks : T_NormalizedChunks | None = None ,
40
+ ) -> T_NormalizedChunks :
42
41
from cubed .vendor .dask .array .core import normalize_chunks
43
42
44
- return normalize_chunks (chunks , shape = shape , limit = limit , dtype = dtype , previous_chunks = previous_chunks )
43
+ return normalize_chunks (
44
+ chunks ,
45
+ shape = shape ,
46
+ limit = limit ,
47
+ dtype = dtype ,
48
+ previous_chunks = previous_chunks ,
49
+ )
45
50
46
51
def from_array (self , data : np .ndarray , chunks , ** kwargs ) -> "CubedArray" :
47
52
from cubed import from_array
@@ -58,10 +63,7 @@ def from_array(self, data: np.ndarray, chunks, **kwargs) -> "CubedArray":
58
63
spec = spec ,
59
64
)
60
65
61
- def rechunk (self , data : "CubedArray" , chunks , ** kwargs ) -> "CubedArray" :
62
- return data .rechunk (chunks , ** kwargs )
63
-
64
- def compute (self , * data : "CubedArray" , ** kwargs ) -> np .ndarray :
66
+ def compute (self , * data : "CubedArray" , ** kwargs ) -> tuple [np .ndarray , ...]:
65
67
from cubed import compute
66
68
67
69
return compute (* data , ** kwargs )
@@ -74,14 +76,14 @@ def array_api(self) -> Any:
74
76
75
77
def reduction (
76
78
self ,
77
- arr : T_ChunkedArray ,
79
+ arr : "CubedArray" ,
78
80
func : Callable ,
79
- combine_func : Optional [ Callable ] = None ,
80
- aggregate_func : Optional [ Callable ] = None ,
81
- axis : Optional [ Union [ int , Sequence [int ]]] = None ,
82
- dtype : Optional [ np .dtype ] = None ,
81
+ combine_func : Callable | None = None ,
82
+ aggregate_func : Callable | None = None ,
83
+ axis : int | Sequence [int ] | None = None ,
84
+ dtype : np .dtype | None = None ,
83
85
keepdims : bool = False ,
84
- ) -> T_ChunkedArray :
86
+ ) -> "CubedArray" :
85
87
from cubed .core .ops import reduction
86
88
87
89
return reduction (
@@ -96,16 +98,21 @@ def reduction(
96
98
97
99
def map_blocks (
98
100
self ,
99
- func ,
100
- * args ,
101
- dtype = None ,
102
- chunks = None ,
103
- drop_axis = [] ,
104
- new_axis = None ,
101
+ func : Callable ,
102
+ * args : Any ,
103
+ dtype : np . typing . DTypeLike | None = None ,
104
+ chunks : tuple [ int , ...] | None = None ,
105
+ drop_axis : int | Sequence [ int ] | None = None ,
106
+ new_axis : int | Sequence [ int ] | None = None ,
105
107
** kwargs ,
106
108
):
107
109
from cubed .core .ops import map_blocks
108
110
111
+ if drop_axis is None :
112
+ # TODO should fix this upstream in cubed to match dask
113
+ # see https://github.com/pydata/xarray/pull/7019#discussion_r1196729489
114
+ drop_axis = []
115
+
109
116
return map_blocks (
110
117
func ,
111
118
* args ,
@@ -118,14 +125,14 @@ def map_blocks(
118
125
119
126
def blockwise (
120
127
self ,
121
- func ,
122
- out_ind ,
128
+ func : Callable ,
129
+ out_ind : Iterable ,
123
130
* args : Any ,
124
131
# can't type this as mypy assumes args are all same type, but blockwise args alternate types
125
- dtype = None ,
126
- adjust_chunks = None ,
127
- new_axes = None ,
128
- align_arrays = True ,
132
+ dtype : np . dtype | None = None ,
133
+ adjust_chunks : dict [ Any , Callable ] | None = None ,
134
+ new_axes : dict [ Any , int ] | None = None ,
135
+ align_arrays : bool = True ,
129
136
target_store = None ,
130
137
** kwargs ,
131
138
):
@@ -147,16 +154,16 @@ def blockwise(
147
154
148
155
def apply_gufunc (
149
156
self ,
150
- func ,
151
- signature ,
152
- * args ,
153
- axes = None ,
154
- axis = None ,
155
- keepdims = False ,
156
- output_dtypes = None ,
157
- output_sizes = None ,
158
- vectorize = None ,
159
- allow_rechunk = False ,
157
+ func : Callable ,
158
+ signature : str ,
159
+ * args : Any ,
160
+ axes : Sequence [ tuple [ int , ...]] | None = None ,
161
+ axis : int | None = None ,
162
+ keepdims : bool = False ,
163
+ output_dtypes : Sequence [ np . typing . DTypeLike ] | None = None ,
164
+ output_sizes : dict [ str , int ] | None = None ,
165
+ vectorize : bool | None = None ,
166
+ allow_rechunk : bool = False ,
160
167
** kwargs ,
161
168
):
162
169
if allow_rechunk :
@@ -181,17 +188,19 @@ def apply_gufunc(
181
188
)
182
189
183
190
def unify_chunks (
184
- self , * args , ** kwargs
185
- ) -> tuple [dict [str , T_Chunks ], list ["CubedArray" ]]:
191
+ self ,
192
+ * args : Any , # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
193
+ ** kwargs ,
194
+ ) -> tuple [dict [str , T_NormalizedChunks ], list ["CubedArray" ]]:
186
195
from cubed .core import unify_chunks
187
196
188
197
return unify_chunks (* args , ** kwargs )
189
198
190
199
def store (
191
200
self ,
192
201
sources : Union ["CubedArray" , Sequence ["CubedArray" ]],
193
- targets : Union [ "ZarrArray" , Sequence [ "ZarrArray" ]] ,
194
- ** kwargs : dict [ str , Any ] ,
202
+ targets : Any ,
203
+ ** kwargs ,
195
204
):
196
205
"""Used when writing to any backend."""
197
206
from cubed .core .ops import store
0 commit comments