Skip to content

Commit 6a8f723

Browse files
committed
typing fixes to match xarray 7019 (now merged)
1 parent 3ca40ec commit 6a8f723

File tree

1 file changed

+58
-49
lines changed

1 file changed

+58
-49
lines changed

cubed_xarray/cubedmanager.py

Lines changed: 58 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import functools
24
import sys
35
from abc import ABC, abstractmethod
@@ -10,20 +12,17 @@
1012
from xarray.core import utils
1113
from xarray.core.parallelcompat import ChunkManagerEntrypoint
1214
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
13-
from xarray.core.types import T_Chunks, T_NormalizedChunks
14-
15-
T_ChunkedArray = TypeVar("T_ChunkedArray")
1615

17-
CHUNK_MANAGERS: dict[str, type["ChunkManagerEntrypoint"]] = {}
1816

1917
if TYPE_CHECKING:
20-
from xarray.core.types import CubedArray, ZarrArray
18+
from xarray.core.types import T_Chunks, T_NormalizedChunks
19+
from cubed import Array as CubedArray
2120

2221

2322
class CubedManager(ChunkManagerEntrypoint["CubedArray"]):
2423
array_cls: type["CubedArray"]
2524

26-
def __init__(self):
25+
def __init__(self) -> None:
2726
from cubed import Array
2827

2928
self.array_cls = Array
@@ -33,15 +32,21 @@ def chunks(self, data: "CubedArray") -> T_NormalizedChunks:
3332

3433
def normalize_chunks(
3534
self,
36-
chunks: T_Chunks,
37-
shape: Union[tuple[int], None] = None,
38-
limit: Union[int, None] = None,
39-
dtype: Union[np.dtype, None] = None,
40-
previous_chunks: T_NormalizedChunks = None,
41-
) -> tuple[tuple[int, ...], ...]:
35+
chunks: T_Chunks | T_NormalizedChunks,
36+
shape: tuple[int, ...] | None = None,
37+
limit: int | None = None,
38+
dtype: np.dtype | None = None,
39+
previous_chunks: T_NormalizedChunks | None = None,
40+
) -> T_NormalizedChunks:
4241
from cubed.vendor.dask.array.core import normalize_chunks
4342

44-
return normalize_chunks(chunks, shape=shape, limit=limit, dtype=dtype, previous_chunks=previous_chunks)
43+
return normalize_chunks(
44+
chunks,
45+
shape=shape,
46+
limit=limit,
47+
dtype=dtype,
48+
previous_chunks=previous_chunks,
49+
)
4550

4651
def from_array(self, data: np.ndarray, chunks, **kwargs) -> "CubedArray":
4752
from cubed import from_array
@@ -58,10 +63,7 @@ def from_array(self, data: np.ndarray, chunks, **kwargs) -> "CubedArray":
5863
spec=spec,
5964
)
6065

61-
def rechunk(self, data: "CubedArray", chunks, **kwargs) -> "CubedArray":
62-
return data.rechunk(chunks, **kwargs)
63-
64-
def compute(self, *data: "CubedArray", **kwargs) -> np.ndarray:
66+
def compute(self, *data: "CubedArray", **kwargs) -> tuple[np.ndarray, ...]:
6567
from cubed import compute
6668

6769
return compute(*data, **kwargs)
@@ -74,14 +76,14 @@ def array_api(self) -> Any:
7476

7577
def reduction(
7678
self,
77-
arr: T_ChunkedArray,
79+
arr: "CubedArray",
7880
func: Callable,
79-
combine_func: Optional[Callable] = None,
80-
aggregate_func: Optional[Callable] = None,
81-
axis: Optional[Union[int, Sequence[int]]] = None,
82-
dtype: Optional[np.dtype] = None,
81+
combine_func: Callable | None = None,
82+
aggregate_func: Callable | None = None,
83+
axis: int | Sequence[int] | None = None,
84+
dtype: np.dtype | None = None,
8385
keepdims: bool = False,
84-
) -> T_ChunkedArray:
86+
) -> "CubedArray":
8587
from cubed.core.ops import reduction
8688

8789
return reduction(
@@ -96,16 +98,21 @@ def reduction(
9698

9799
def map_blocks(
98100
self,
99-
func,
100-
*args,
101-
dtype=None,
102-
chunks=None,
103-
drop_axis=[],
104-
new_axis=None,
101+
func: Callable,
102+
*args: Any,
103+
dtype: np.typing.DTypeLike | None = None,
104+
chunks: tuple[int, ...] | None = None,
105+
drop_axis: int | Sequence[int] | None = None,
106+
new_axis: int | Sequence[int] | None = None,
105107
**kwargs,
106108
):
107109
from cubed.core.ops import map_blocks
108110

111+
if drop_axis is None:
112+
# TODO should fix this upstream in cubed to match dask
113+
# see https://github.com/pydata/xarray/pull/7019#discussion_r1196729489
114+
drop_axis = []
115+
109116
return map_blocks(
110117
func,
111118
*args,
@@ -118,14 +125,14 @@ def map_blocks(
118125

119126
def blockwise(
120127
self,
121-
func,
122-
out_ind,
128+
func: Callable,
129+
out_ind: Iterable,
123130
*args: Any,
124131
# can't type this as mypy assumes args are all same type, but blockwise args alternate types
125-
dtype=None,
126-
adjust_chunks=None,
127-
new_axes=None,
128-
align_arrays=True,
132+
dtype: np.dtype | None = None,
133+
adjust_chunks: dict[Any, Callable] | None = None,
134+
new_axes: dict[Any, int] | None = None,
135+
align_arrays: bool = True,
129136
target_store=None,
130137
**kwargs,
131138
):
@@ -147,16 +154,16 @@ def blockwise(
147154

148155
def apply_gufunc(
149156
self,
150-
func,
151-
signature,
152-
*args,
153-
axes=None,
154-
axis=None,
155-
keepdims=False,
156-
output_dtypes=None,
157-
output_sizes=None,
158-
vectorize=None,
159-
allow_rechunk=False,
157+
func: Callable,
158+
signature: str,
159+
*args: Any,
160+
axes: Sequence[tuple[int, ...]] | None = None,
161+
axis: int | None = None,
162+
keepdims: bool = False,
163+
output_dtypes: Sequence[np.typing.DTypeLike] | None = None,
164+
output_sizes: dict[str, int] | None = None,
165+
vectorize: bool | None = None,
166+
allow_rechunk: bool = False,
160167
**kwargs,
161168
):
162169
if allow_rechunk:
@@ -181,17 +188,19 @@ def apply_gufunc(
181188
)
182189

183190
def unify_chunks(
184-
self, *args, **kwargs
185-
) -> tuple[dict[str, T_Chunks], list["CubedArray"]]:
191+
self,
192+
*args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
193+
**kwargs,
194+
) -> tuple[dict[str, T_NormalizedChunks], list["CubedArray"]]:
186195
from cubed.core import unify_chunks
187196

188197
return unify_chunks(*args, **kwargs)
189198

190199
def store(
191200
self,
192201
sources: Union["CubedArray", Sequence["CubedArray"]],
193-
targets: Union["ZarrArray", Sequence["ZarrArray"]],
194-
**kwargs: dict[str, Any],
202+
targets: Any,
203+
**kwargs,
195204
):
196205
"""Used when writing to any backend."""
197206
from cubed.core.ops import store

0 commit comments

Comments
 (0)