Skip to content

Commit 81b7a41

Browse files
committed
sparse: improved _sputils signatures
1 parent ab0e391 commit 81b7a41

File tree

1 file changed

+65
-41
lines changed

1 file changed

+65
-41
lines changed

scipy-stubs/sparse/_sputils.pyi

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from collections.abc import Iterable, Sequence as Seq
2-
from typing import Any, Final, Literal, Protocol, TypeAlias, TypedDict, TypeVar, overload, type_check_only
1+
from collections.abc import Iterable
2+
from typing import Any, Final, Literal as L, Protocol, TypeAlias, TypedDict, TypeVar, overload, type_check_only
33
from typing_extensions import TypeIs
44

55
import numpy as np
66
import numpy.typing as npt
77
import optype as op
88
import optype.numpy as onp
9+
import optype.numpy.compat as npc
910
from scipy._typing import OrderKACF
1011
from scipy.sparse import (
1112
bsr_array,
@@ -19,7 +20,7 @@ from scipy.sparse import (
1920
dia_array,
2021
dia_matrix,
2122
)
22-
from scipy.sparse._typing import ToDType
23+
from ._typing import Scalar
2324

2425
__all__ = [
2526
"broadcast_shapes",
@@ -38,24 +39,20 @@ __all__ = [
3839
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...], default=Any)
3940
_DTypeT = TypeVar("_DTypeT", bound=np.dtype[Any])
4041
_SCT = TypeVar("_SCT", bound=np.generic, default=Any)
41-
_IntT = TypeVar("_IntT", bound=np.integer[Any])
42-
_NonIntDTypeT = TypeVar(
43-
"_NonIntDTypeT",
44-
bound=np.dtype[np.inexact[Any] | np.flexible | np.datetime64 | np.timedelta64 | np.object_],
45-
)
42+
_IntT = TypeVar("_IntT", bound=npc.integer)
43+
_NonIntDTypeT = TypeVar("_NonIntDTypeT", bound=np.dtype[npc.inexact | np.flexible | np.datetime64 | np.timedelta64 | np.object_])
4644

47-
_SupportedScalar: TypeAlias = np.bool_ | np.integer[Any] | np.float32 | np.float64 | np.longdouble | np.complexfloating[Any, Any]
4845
_ShapeLike: TypeAlias = Iterable[op.CanIndex]
4946
_ScalarLike: TypeAlias = complex | bytes | str | np.generic | onp.Array0D
5047
_SequenceLike: TypeAlias = tuple[_ScalarLike, ...] | list[_ScalarLike] | onp.Array1D
5148
_MatrixLike: TypeAlias = tuple[_SequenceLike, ...] | list[_SequenceLike] | onp.Array2D
5249

53-
_ToArray: TypeAlias = onp.CanArrayND[_SCT, _ShapeT] | onp.SequenceND[_SCT | complex | bytes | str]
54-
_ToArray2D: TypeAlias = onp.CanArrayND[_SCT, _ShapeT] | Seq[Seq[_SCT | complex | bytes | str] | onp.CanArrayND[_SCT]]
50+
_IntP: TypeAlias = np.int32 | np.int64
51+
_UIntP: TypeAlias = np.uint32 | np.uint64
5552

5653
@type_check_only
5754
class _ReshapeKwargs(TypedDict, total=False):
58-
order: Literal["C", "F"]
55+
order: L["C", "F"]
5956
copy: bool
6057

6158
@type_check_only
@@ -65,18 +62,21 @@ class _SizedIndexIterable(Protocol):
6562

6663
###
6764

68-
supported_dtypes: Final[list[type[_SupportedScalar]]] = ...
65+
supported_dtypes: Final[list[type[Scalar]]] = ...
6966

7067
#
7168
# NOTE: Technically any `numpy.generic` could be returned, but we only care about the supported scalar types in `scipy.sparse`.
72-
def upcast(*args: npt.DTypeLike) -> _SupportedScalar: ...
73-
def upcast_char(*args: npt.DTypeLike) -> _SupportedScalar: ...
74-
def upcast_scalar(dtype: npt.DTypeLike, scalar: _ScalarLike) -> np.dtype[_SupportedScalar]: ...
69+
def upcast(*args: npt.DTypeLike) -> Scalar: ...
70+
def upcast_char(*args: npt.DTypeLike) -> Scalar: ...
71+
@overload
72+
def upcast_scalar(dtype: onp.ToDType[_SCT], scalar: onp.ToScalar) -> np.dtype[_SCT]: ...
73+
@overload
74+
def upcast_scalar(dtype: npt.DTypeLike, scalar: onp.ToScalar) -> np.dtype[Any]: ...
7575

7676
#
7777
def downcast_intp_index(
78-
arr: onp.Array[_ShapeT, np.bool_ | np.integer[Any] | np.floating[Any] | np.timedelta64 | np.object_],
79-
) -> onp.Array[_ShapeT, np.intp]: ...
78+
arr: onp.Array[_ShapeT, np.bool_ | npc.integer | npc.floating | np.timedelta64 | np.object_],
79+
) -> onp.Array[_ShapeT, _IntP]: ...
8080

8181
#
8282
@overload
@@ -88,29 +88,37 @@ def to_native(A: onp.HasDType[_DTypeT]) -> np.ndarray[Any, _DTypeT]: ...
8888

8989
#
9090
def getdtype(
91-
dtype: ToDType[_SCT] | None,
91+
dtype: onp.ToDType[_SCT] | None,
9292
a: onp.HasDType[np.dtype[_SCT]] | None = None,
93-
default: ToDType[_SCT] | None = None,
93+
default: onp.ToDType[_SCT] | None = None,
9494
) -> np.dtype[_SCT]: ...
9595

9696
#
9797
@overload
98-
def getdata(obj: _SCT | complex | bytes | str, dtype: ToDType[_SCT] | None = None, copy: bool = False) -> onp.Array0D[_SCT]: ...
98+
def getdata(obj: _SCT, dtype: onp.ToDType[_SCT] | None = None, copy: bool = False) -> onp.Array0D[_SCT]: ...
99+
@overload
100+
def getdata(obj: onp.ToComplex, dtype: onp.ToDType[_SCT], copy: bool = False) -> onp.Array0D[_SCT]: ...
99101
@overload
100-
def getdata(obj: _ToArray[_SCT, _ShapeT], dtype: ToDType[_SCT] | None = None, copy: bool = False) -> onp.Array[_ShapeT, _SCT]: ...
102+
def getdata(obj: onp.ToComplexStrict1D, dtype: onp.ToDType[_SCT], copy: bool = False) -> onp.Array1D[_SCT]: ...
103+
@overload
104+
def getdata(obj: onp.ToComplexStrict2D, dtype: onp.ToDType[_SCT], copy: bool = False) -> onp.Array2D[_SCT]: ...
105+
@overload
106+
def getdata(obj: onp.ToComplexStrict3D, dtype: onp.ToDType[_SCT], copy: bool = False) -> onp.Array3D[_SCT]: ...
107+
@overload
108+
def getdata(obj: onp.ToArrayND[_SCT, _SCT], dtype: onp.ToDType[_SCT] | None = None, copy: bool = False) -> onp.ArrayND[_SCT]: ...
101109

102110
#
103111
def get_index_dtype(
104112
arrays: tuple[onp.ToInt | onp.ToIntND, ...] = (),
105113
maxval: onp.ToFloat | None = None,
106114
check_contents: op.CanBool = False,
107-
) -> np.int32 | np.int64: ...
115+
) -> _IntP: ...
108116

109117
# NOTE: The inline annotations (`(np.dtype) -> np.dtype`) are incorrect.
110118
@overload
111-
def get_sum_dtype(dtype: np.dtype[np.unsignedinteger[Any]]) -> type[np.uint]: ...
119+
def get_sum_dtype(dtype: np.dtype[npc.unsignedinteger]) -> type[_UIntP]: ...
112120
@overload
113-
def get_sum_dtype(dtype: np.dtype[np.bool_ | np.signedinteger[Any]]) -> type[np.int_]: ...
121+
def get_sum_dtype(dtype: np.dtype[np.bool_ | npc.signedinteger]) -> type[_IntP]: ...
114122
@overload
115123
def get_sum_dtype(dtype: _NonIntDTypeT) -> _NonIntDTypeT: ...
116124

@@ -124,79 +132,95 @@ def ismatrix(t: object) -> TypeIs[_MatrixLike]: ...
124132
def isdense(x: object) -> TypeIs[onp.Array]: ...
125133

126134
#
127-
def validateaxis(axis: Literal[-2, -1, 0, 1] | bool | np.bool_ | np.integer[Any] | None) -> None: ...
135+
def validateaxis(axis: L[-2, -1, 0, 1] | bool | np.bool_ | npc.integer | None) -> None: ...
128136
def check_shape(
129137
args: _ShapeLike | tuple[_ShapeLike, ...],
130138
current_shape: tuple[int, ...] | None = None,
131139
*,
132140
allow_nd: tuple[int, ...] = (2,),
133141
) -> tuple[int, ...]: ...
134-
def check_reshape_kwargs(kwargs: _ReshapeKwargs) -> Literal["C", "F"] | bool: ...
142+
def check_reshape_kwargs(kwargs: _ReshapeKwargs) -> L["C", "F"] | bool: ...
135143

136144
#
137145
def matrix(
138-
object: _ToArray2D[_SCT],
139-
dtype: ToDType[_SCT] | type | str | None = None,
146+
object: onp.ToArray2D[_SCT],
147+
dtype: onp.ToDType[_SCT] | type | str | None = None,
140148
*,
141-
copy: Literal[0, 1, 2] | bool | None = True,
149+
copy: L[0, 1, 2] | bool | None = True,
142150
order: OrderKACF = "K",
143151
subok: bool = False,
144-
ndmin: Literal[0, 1, 2] = 0,
152+
ndmin: L[0, 1, 2] = 0,
145153
like: onp.CanArrayFunction | None = None,
146154
) -> onp.Matrix[_SCT]: ...
147155

148156
#
149-
def asmatrix(data: _ToArray2D[_SCT], dtype: ToDType[_SCT] | type | str | None = None) -> onp.Matrix[_SCT]: ...
157+
@overload
158+
def asmatrix(data: onp.ToArray2D[Any], dtype: onp.ToDType[_SCT]) -> onp.Matrix[_SCT]: ...
159+
@overload
160+
def asmatrix(data: onp.ToArray2D[_SCT], dtype: onp.ToDType[_SCT] | None = None) -> onp.Matrix[_SCT]: ...
161+
@overload
162+
def asmatrix(data: onp.ToArray2D[Any], dtype: npt.DTypeLike) -> onp.Matrix[Any]: ...
150163

151164
#
152165
@overload # BSR/CSC/CSR, dtype: <default>
153166
def safely_cast_index_arrays(
154167
A: bsr_array | bsr_matrix | csc_array | csc_matrix | csr_array | csr_matrix,
155-
idx_dtype: ToDType[np.int32] = ...,
168+
idx_dtype: onp.ToDType[np.int32] = ..., # = np.int32
156169
msg: str = "",
157170
) -> tuple[onp.Array1D[np.int32], onp.Array1D[np.int32]]: ...
158171
@overload # BSR/CSC/CSR, dtype: <known>
159172
def safely_cast_index_arrays(
160173
A: bsr_array | bsr_matrix | csc_array | csc_matrix | csr_array | csr_matrix,
161-
idx_dtype: ToDType[_IntT],
174+
idx_dtype: onp.ToDType[_IntT],
162175
msg: str = "",
163176
) -> tuple[onp.Array1D[_IntT], onp.Array1D[_IntT]]: ...
164177
@overload # 2d COO, dtype: <default>
165178
def safely_cast_index_arrays(
166179
A: coo_array[Any, tuple[int, int]] | coo_matrix,
167-
idx_dtype: ToDType[np.int32] = ...,
180+
idx_dtype: onp.ToDType[np.int32] = ..., # = np.int32
168181
msg: str = "",
169182
) -> tuple[onp.Array1D[np.int32], onp.Array1D[np.int32]]: ...
170183
@overload # 2d COO, dtype: <known>
171184
def safely_cast_index_arrays(
172185
A: coo_array[Any, tuple[int, int]] | coo_matrix,
173-
idx_dtype: ToDType[_IntT],
186+
idx_dtype: onp.ToDType[_IntT],
174187
msg: str = "",
175188
) -> tuple[onp.Array1D[_IntT], onp.Array1D[_IntT]]: ...
176189
@overload # nd COO, dtype: <default>
177190
def safely_cast_index_arrays(
178191
A: coo_array,
179-
idx_dtype: ToDType[np.int32] = ...,
192+
idx_dtype: onp.ToDType[np.int32] = ..., # = np.int32
180193
msg: str = "",
181194
) -> tuple[onp.Array1D[np.int32], ...]: ...
182195
@overload # nd COO, dtype: <known>
183196
def safely_cast_index_arrays(
184197
A: coo_array,
185-
idx_dtype: ToDType[_IntT],
198+
idx_dtype: onp.ToDType[_IntT],
186199
msg: str = "",
187200
) -> tuple[onp.Array1D[_IntT], ...]: ...
188201
@overload # DIA, dtype: <default>
189202
def safely_cast_index_arrays(
190203
A: dia_array | dia_matrix,
191-
idx_dtype: ToDType[np.int32] = ...,
204+
idx_dtype: onp.ToDType[np.int32] = ..., # = np.int32
192205
msg: str = "",
193206
) -> onp.Array1D[np.int32]: ...
194207
@overload # DIA, dtype: <known>
195208
def safely_cast_index_arrays(
196209
A: dia_array | dia_matrix,
197-
idx_dtype: ToDType[_IntT],
210+
idx_dtype: onp.ToDType[_IntT],
198211
msg: str = "",
199212
) -> onp.Array1D[_IntT]: ...
200213

201214
#
202-
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: ...
215+
@overload
216+
def broadcast_shapes() -> tuple[()]: ...
217+
@overload
218+
def broadcast_shapes(shape0: tuple[()], /, *shapes: tuple[()]) -> tuple[()]: ...
219+
@overload
220+
def broadcast_shapes(shape0: tuple[int], /, *shapes: onp.AtMost1D) -> tuple[int]: ...
221+
@overload
222+
def broadcast_shapes(shape0: tuple[int, int], /, *shapes: onp.AtMost2D) -> tuple[int, int]: ...
223+
@overload
224+
def broadcast_shapes(shape0: tuple[int, int, int], /, *shapes: onp.AtMost3D) -> tuple[int, int, int]: ...
225+
@overload
226+
def broadcast_shapes(shape0: _ShapeT, /, *shapes: tuple[()] | _ShapeT) -> _ShapeT: ...

0 commit comments

Comments
 (0)