diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index 71a1bf8e2..178062430 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -10,7 +10,7 @@
 from functools import partial, reduce, singledispatch
 from itertools import repeat
 from operator import and_, or_, sub
-from typing import Literal, TypeVar
+from typing import Generic, Literal, TypeVar
 from warnings import warn

 import numpy as np
@@ -36,11 +36,16 @@
 from .index import _subset, make_slice

 if typing.TYPE_CHECKING:
-    from collections.abc import Collection, Iterable, Sequence
-    from typing import Any
+    from collections.abc import Collection, Iterable, Iterator, Sequence
+    from typing import Any, Self, TypeGuard

     from pandas.api.extensions import ExtensionDtype

+    from anndata._core.aligned_mapping import AlignedMappingBase
+
+    _Array = SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray
+
+K = TypeVar("K")
 T = TypeVar("T")

 ###################
@@ -49,40 +54,40 @@


 # Pretty much just for maintaining order of keys
-class OrderedSet(MutableSet):
-    def __init__(self, vals=()):
+class OrderedSet(MutableSet, Generic[T]):
+    def __init__(self, vals: Iterable[T] = ()) -> None:
         self.dict = OrderedDict(zip(vals, repeat(None)))

-    def __contains__(self, val):
+    def __contains__(self, val: object) -> bool:
         return val in self.dict

-    def __iter__(self):
+    def __iter__(self) -> Iterator[T]:
         return iter(self.dict)

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.dict)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return "OrderedSet: {" + ", ".join(map(str, self)) + "}"

-    def copy(self):
-        return OrderedSet(self.dict.copy())
+    def copy(self) -> Self:
+        return type(self)(self.dict.copy())

-    def add(self, val):
+    def add(self, val: T) -> None:
         self.dict[val] = None

-    def union(self, *vals) -> OrderedSet:
+    def union(self, *vals: Iterable[T]) -> Self:
         return reduce(or_, vals, self)

-    def discard(self, val):
+    def discard(self, val: T) -> None:
         if val in self:
             del self.dict[val]

-    def difference(self, *vals) -> OrderedSet:
+    def difference(self, *vals: Iterable[T]) -> Self:
         return reduce(sub, vals, self)


-def union_keys(ds: Collection) -> OrderedSet:
+def union_keys(ds: Collection[Iterable[T]]) -> OrderedSet[T]:
     return reduce(or_, ds, OrderedSet())
@@ -94,11 +99,11 @@
 class MissingVal:
     """Represents a missing value."""


-def is_missing(v) -> bool:
+def is_missing(v: object | MissingVal) -> TypeGuard[MissingVal]:
     return v is MissingVal


-def not_missing(v) -> bool:
+def not_missing(v: T | MissingVal) -> TypeGuard[T]:
     return v is not MissingVal
@@ -327,7 +332,7 @@ def check_combinable_cols(cols: list[pd.Index], join: Literal["inner", "outer"])


 # TODO: open PR or feature request to cupy
-def _cp_block_diag(mats, format=None, dtype=None):
+def _cp_block_diag(mats: Iterable[CupyArray], format=None, dtype=None):
     """
     Modified version of scipy.sparse.block_diag for cupy sparse.
     """
@@ -363,7 +368,7 @@
     ).asformat(format)


-def _dask_block_diag(mats):
+def _dask_block_diag(mats: list[DaskArray]) -> DaskArray:
     from itertools import permutations

     import dask.array as da
@@ -511,7 +516,7 @@ class Reindexer:
         Together with `old_pos` this forms a mapping.
     """

-    def __init__(self, old_idx, new_idx):
+    def __init__(self, old_idx: pd.Index, new_idx: pd.Index):
         self.old_idx = old_idx
         self.new_idx = new_idx
         self.no_change = new_idx.equals(old_idx)
@@ -524,10 +529,14 @@ def __init__(self, old_idx, new_idx):
         self.new_pos = new_pos[mask]
         self.old_pos = old_pos[mask]

-    def __call__(self, el, *, axis=1, fill_value=None):
+    def __call__(
+        self, el: _Array, *, axis: Literal[0, 1] = 1, fill_value: object | None = None
+    ) -> _Array:
         return self.apply(el, axis=axis, fill_value=fill_value)

-    def apply(self, el, *, axis, fill_value=None):
+    def apply(
+        self, el: _Array, *, axis: Literal[0, 1], fill_value: object | None = None
+    ) -> _Array:
         """
         Reindex element so el[axis] is aligned to self.new_idx.
@@ -724,7 +733,7 @@ def merge_indices(
         raise ValueError(msg)


-def default_fill_value(els):
+def default_fill_value(els: Iterable[_Array]) -> int | float:
     """Given some arrays, returns what the default fill value should be.

     This is largely due to backwards compat, and might not be the ideal solution.
@@ -742,7 +751,7 @@
     return np.nan


-def gen_reindexer(new_var: pd.Index, cur_var: pd.Index):
+def gen_reindexer(new_var: pd.Index, cur_var: pd.Index) -> Reindexer:
     """
     Given a new set of var_names, and a current set, generates a function which will
     reindex a matrix to be aligned with the new set.
@@ -763,14 +772,20 @@ def gen_reindexer(new_var: pd.Index, cur_var: pd.Index):
     return Reindexer(cur_var, new_var)


-def np_bool_to_pd_bool_array(df: pd.DataFrame):
+def np_bool_to_pd_bool_array(df: pd.DataFrame) -> pd.DataFrame:
     for col_name, col_type in dict(df.dtypes).items():
         if col_type is np.dtype(bool):
             df[col_name] = pd.array(df[col_name].values)
     return df


-def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
+def concat_arrays(
+    arrays: Iterable[_Array],
+    reindexers: Sequence[Reindexer] | Sequence[Callable[[_Array], _Array]],
+    axis: Literal[0, 1] = 0,
+    index: pd.Index | None = None,
+    fill_value: object | None = None,
+):
     arrays = list(arrays)
     if fill_value is None:
         fill_value = default_fill_value(arrays)
@@ -897,7 +912,7 @@ def gen_inner_reindexers(els, new_index, axis: Literal[0, 1] = 0):
     return reindexers


-def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
+def gen_outer_reindexers(els, shapes) -> list[Reindexer] | list[Callable[[T], T]]:
     if all(isinstance(el, pd.DataFrame) for el in els if not_missing(el)):
         reindexers = [
             (lambda x: x)
@@ -941,7 +956,7 @@ def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
 def missing_element(
     n: int,
-    els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
+    els: list[_Array],
     axis: Literal[0, 1] = 0,
     fill_value: Any | None = None,
     off_axis_size: int = 0,
@@ -960,15 +975,20 @@ def missing_element(

 def outer_concat_aligned_mapping(
-    mappings, *, reindexers=None, index=None, axis=0, fill_value=None
-):
-    result = {}
+    mappings: Collection[AlignedMappingBase],
+    *,
+    reindexers: Sequence[Reindexer] | Sequence[Callable[[T], T]] | None = None,
+    index: pd.Index | None = None,
+    axis: Literal[0, 1] = 0,
+    fill_value: object | None = None,
+) -> dict[str, _Array]:
+    result: dict[str, _Array] = {}
     ns = [m.parent.shape[axis] for m in mappings]

     for k in union_keys(mappings):
         els = [m.get(k, MissingVal) for m in mappings]
         if reindexers is None:
-            cur_reindexers = gen_outer_reindexers(els, ns, new_index=index, axis=axis)
+            cur_reindexers = gen_outer_reindexers(els, ns)
         else:
             cur_reindexers = reindexers
@@ -1004,8 +1024,8 @@ def outer_concat_aligned_mapping(

 def concat_pairwise_mapping(
     mappings: Collection[Mapping], shapes: Collection[int], join_keys=intersect_keys
-):
-    result = {}
+) -> dict[str, _Array]:
+    result: dict[str, _Array] = {}
     if any(any(isinstance(v, SpArray) for v in m.values()) for m in mappings):
         sparse_class = sparse.csr_array
     else:
@@ -1067,7 +1087,12 @@ def axis_indices(adata: AnnData, axis: Literal["obs", 0, "var", 1]) -> pd.Index:


 # TODO: Resolve https://github.com/scverse/anndata/issues/678 and remove this function
-def concat_Xs(adatas, reindexers, axis, fill_value):
+def concat_Xs(
+    adatas: Iterable[AnnData],
+    reindexers: Sequence[Reindexer] | Sequence[Callable[[_Array], _Array]],
+    axis: Literal[0, 1],
+    fill_value: object | None,
+):
     """
     Shimy until support for some missing X's is implemented.
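A note on the `TypeGuard` change in the patch: `is_missing` and `not_missing` previously returned plain `bool`, so type checkers could not narrow `T | MissingVal` after a filter; with `TypeGuard` they can. Below is a minimal self-contained sketch of the pattern, not the module's actual call sites. Since `MissingVal` is used as a sentinel class object (`v is MissingVal`), the sketch types it as `type[MissingVal]`:

    from typing import TypeGuard, TypeVar

    T = TypeVar("T")


    class MissingVal:
        """Sentinel: the class object itself marks a missing value."""


    def not_missing(v: T | type[MissingVal]) -> TypeGuard[T]:
        # After `if not_missing(el): ...`, checkers treat `el` as T.
        return v is not MissingVal


    els: list[int | type[MissingVal]] = [1, MissingVal, 3]
    present = [el for el in els if not_missing(el)]  # inferred as list[int]
    print(present)  # [1, 3]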
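Similarly, making `OrderedSet` a `Generic[T]` with `Self` return types on `copy`/`union`/`difference` lets element types flow through set operations. A hypothetical usage sketch (it imports from a private module, purely for illustration):

    from anndata._core.merge import OrderedSet  # private API, illustration only

    keys_a = OrderedSet(["obs", "var", "X"])  # checkers infer OrderedSet[str]
    keys_b = OrderedSet(["var", "layers"])

    merged = keys_a.union(keys_b)  # `Self` preserves OrderedSet[str]
    merged.discard("X")
    print(merged)  # OrderedSet: {obs, var, layers}

Note how insertion order is preserved across the union, which is the whole point of the class ("Pretty much just for maintaining order of keys").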
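Finally, the `_Array` alias on `Reindexer.__call__`/`apply` covers dense numpy, scipy sparse, and dask inputs alike. A sketch of a reindex via `gen_reindexer` (also private API; the printed values assume the default dense path, where a column absent from the old index is filled with the requested fill value):

    import numpy as np
    import pandas as pd

    from anndata._core.merge import gen_reindexer  # private API, illustration only

    cur_var = pd.Index(["gene_a", "gene_b"])
    new_var = pd.Index(["gene_b", "gene_c"])
    reindexer = gen_reindexer(new_var, cur_var)

    X = np.array([[1.0, 2.0], [3.0, 4.0]])
    # Align columns (axis=1) to new_var; "gene_c" has no source column.
    X_new = reindexer(X, axis=1, fill_value=np.nan)
    print(X_new)  # [[ 2. nan]
                  #  [ 4. nan]]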