Skip to content

Commit 8ee626e

Browse files
committed
sparse: generic sparray
1 parent 4c116f1 commit 8ee626e

File tree

10 files changed

+47
-51
lines changed

10 files changed

+47
-51
lines changed

scipy-stubs/sparse/_base.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ class _spbase(Generic[_SCT_co, _ShapeT_co]):
707707
copy: bool = True,
708708
) -> Self: ...
709709
@overload # known type -> sparray
710-
def astype( # pyright: ignore[reportOverlappingOverload]
710+
def astype(
711711
self: bsr_array,
712712
/,
713713
dtype: onp.ToDType[_SCT],
@@ -1027,11 +1027,13 @@ class _spbase(Generic[_SCT_co, _ShapeT_co]):
10271027
@overload
10281028
def tolil(self: spmatrix, /, copy: bool = False) -> lil_matrix[_SCT_co]: ...
10291029

1030-
#
1030+
# NOTE: Don't do this; it's type-unsafe.
10311031
def resize(self, /, shape: ToShapeMin1D) -> None: ...
1032+
1033+
#
10321034
def setdiag(self, /, values: onp.ToComplex1D, k: int = 0) -> None: ...
10331035

1034-
class sparray: ...
1036+
class sparray(Generic[_SCT_co, _ShapeT_co]): ...
10351037

10361038
def issparse(x: object) -> TypeIs[_spbase]: ...
10371039
def isspmatrix(x: object) -> TypeIs[spmatrix]: ...

scipy-stubs/sparse/_bsr.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class _bsr_base(_cs_matrix[_SCT, tuple[int, int]], _minmax_mixin[_SCT, tuple[int
134134
maxprint: int | None = None,
135135
) -> None: ...
136136

137-
class bsr_array(_bsr_base[_SCT], sparray, Generic[_SCT]): ...
137+
class bsr_array(_bsr_base[_SCT], sparray[_SCT, tuple[int, int]], Generic[_SCT]): ...
138138
class bsr_matrix(_bsr_base[_SCT], spmatrix[_SCT], Generic[_SCT]): ... # type: ignore[misc]
139139

140140
def isspmatrix_bsr(x: object) -> TypeIs[bsr_matrix]: ...

scipy-stubs/sparse/_construct.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,10 @@ def bmat(blocks: Seq[Seq[spmatrix[_SCT]]], format: _FmtCOO | None = None, dtype:
821821
def bmat(blocks: _ToBlocks, format: _FmtCOO | None, dtype: onp.ToDType[_SCT]) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
822822
@overload # sparray, blocks: <unknown, unknown dtype>, format: <default>, dtype: <known> (keyword)
823823
def bmat(
824-
blocks: _ToBlocks, format: _FmtCOO | None = None, *, dtype: onp.ToDType[_SCT]
824+
blocks: _ToBlocks,
825+
format: _FmtCOO | None = None,
826+
*,
827+
dtype: onp.ToDType[_SCT],
825828
) -> _COOArray2D[_SCT] | coo_matrix[_SCT]: ...
826829
@overload # sparray, blocks: <unknown, unknown dtype>, format: <default>, dtype: <unknown>
827830
def bmat(blocks: _ToBlocks, format: _FmtCOO | None = None, dtype: npt.DTypeLike | None = None) -> _COOArray2D | coo_matrix: ...

scipy-stubs/sparse/_coo.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ class _coo_base(_data_matrix[_SCT, _ShapeT_co], _minmax_mixin[_SCT, _ShapeT_co],
312312
@overload
313313
def tensordot(self: _spbase[_SupIntT], /, other: _JustND[int], axes: _Axes = 2) -> _ScalarOrDense[_SupIntT]: ...
314314

315-
class coo_array(_coo_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]): ...
315+
class coo_array(_coo_base[_SCT, _ShapeT_co], sparray[_SCT, _ShapeT_co], Generic[_SCT, _ShapeT_co]): ...
316316

317317
class coo_matrix(_coo_base[_SCT, tuple[int, int]], spmatrix[_SCT], Generic[_SCT]): # type: ignore[misc]
318318
@property

scipy-stubs/sparse/_csc.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class _csc_base(_cs_matrix[_SCT, tuple[int, int]], Generic[_SCT]):
3232
@overload
3333
def count_nonzero(self, /, axis: op.CanIndex) -> onp.Array1D[np.intp]: ...
3434

35-
class csc_array(_csc_base[_SCT], sparray, Generic[_SCT]): ...
35+
class csc_array(_csc_base[_SCT], sparray[_SCT, tuple[int, int]], Generic[_SCT]): ...
3636

3737
class csc_matrix(_csc_base[_SCT], spmatrix[_SCT], Generic[_SCT]): # type: ignore[misc]
3838
# NOTE: using `@override` together with `@overload` causes stubtest to crash...

scipy-stubs/sparse/_csr.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ class _csr_base(_cs_matrix[_SCT, _ShapeT_co], Generic[_SCT, _ShapeT_co]):
3030
@overload
3131
def count_nonzero(self: _csr_base[Any, tuple[int]], /, axis: op.CanIndex) -> int: ...
3232
@overload
33-
def count_nonzero(self: _csr_base[Any, tuple[int, int]], /, axis: op.CanIndex) -> onp.Array1D[np.intp]: ...
33+
def count_nonzero(self: _csr_base[Any, tuple[int, int]], /, axis: op.CanIndex) -> onp.Array1D[np.int32 | np.int64]: ...
3434
@overload
35-
def count_nonzero(self: csr_array, /, axis: op.CanIndex) -> int | onp.Array1D[np.intp]: ... # type: ignore[misc]
35+
def count_nonzero(self: csr_array, /, axis: op.CanIndex) -> int | onp.Array1D[np.int32 | np.int64]: ... # type: ignore[misc]
3636

37-
class csr_array(_csr_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]): ...
37+
class csr_array(_csr_base[_SCT, _ShapeT_co], sparray[_SCT, _ShapeT_co], Generic[_SCT, _ShapeT_co]): ...
3838

3939
class csr_matrix(_csr_base[_SCT, tuple[int, int]], spmatrix[_SCT], Generic[_SCT]): # type: ignore[misc]
4040
# NOTE: using `@override` together with `@overload` causes stubtest to crash...

scipy-stubs/sparse/_dia.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class _dia_base(_data_matrix[_SCT, tuple[int, int]], Generic[_SCT]):
125125
maxprint: int | None = None,
126126
) -> None: ...
127127

128-
class dia_array(_dia_base[_SCT], sparray, Generic[_SCT]): ...
128+
class dia_array(_dia_base[_SCT], sparray[_SCT, tuple[int, int]], Generic[_SCT]): ...
129129
class dia_matrix(_dia_base[_SCT], spmatrix[_SCT], Generic[_SCT]): ... # type: ignore[misc]
130130

131131
def isspmatrix_dia(x: object) -> TypeIs[dia_matrix]: ...

scipy-stubs/sparse/_dok.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ class _dok_base(
258258
) -> _dok_base[np.complex128, _1D]: ...
259259

260260
#
261-
class dok_array(_dok_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]):
261+
class dok_array(_dok_base[_SCT, _ShapeT_co], sparray[_SCT, _ShapeT_co], Generic[_SCT, _ShapeT_co]):
262262
# NOTE: This horrible code duplication is required due to the lack of higher-kinded typing (HKT) support.
263263
# https://github.com/python/typing/issues/548
264264
@overload

scipy-stubs/sparse/_extract.pyi

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from typing import Literal, TypeAlias, overload
33
from typing_extensions import TypeVar
44

55
import optype.numpy as onp
6-
from ._base import _spbase
6+
from ._base import _spbase, sparray
77
from ._bsr import bsr_array, bsr_matrix
88
from ._coo import coo_array, coo_matrix
99
from ._csc import csc_array, csc_matrix
@@ -16,19 +16,10 @@ from ._typing import Index1D, Numeric, SPFormat
1616

1717
__all__ = ["find", "tril", "triu"]
1818

19-
_SCT = TypeVar("_SCT", bound=Numeric, default=Numeric)
20-
_ShapeT = TypeVar("_ShapeT", bound=tuple[int] | tuple[int, int], default=tuple[int] | tuple[int, int])
19+
###
2120

21+
_SCT = TypeVar("_SCT", bound=Numeric, default=Numeric)
2222
_ToDense: TypeAlias = onp.CanArrayND[_SCT] | Seq[_SCT] | Seq[Seq[_SCT] | onp.CanArrayND[_SCT]]
23-
_SpArray: TypeAlias = (
24-
bsr_array[_SCT]
25-
| coo_array[_SCT, _ShapeT]
26-
| csc_array[_SCT]
27-
| csr_array[_SCT, _ShapeT]
28-
| dia_array[_SCT]
29-
| dok_array[_SCT, _ShapeT]
30-
| lil_array[_SCT]
31-
)
3223

3324
###
3425

@@ -37,31 +28,31 @@ def find(A: _spbase[_SCT] | _ToDense) -> tuple[Index1D, Index1D, onp.Array1D[_SC
3728

3829
# NOTE: `tril` and `triu` have identical signatures
3930
@overload # sparray -> coo_array (default)
40-
def tril(A: _SpArray[_SCT], k: int = 0, format: Literal["coo"] | None = None) -> coo_array[_SCT, tuple[int, int]]: ...
31+
def tril(A: sparray[_SCT], k: int = 0, format: Literal["coo"] | None = None) -> coo_array[_SCT, tuple[int, int]]: ...
4132
@overload # sparray -> bsr_array (positional)
42-
def tril(A: _SpArray[_SCT], k: int, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
33+
def tril(A: sparray[_SCT], k: int, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
4334
@overload # sparray -> bsr_array (keyword)
44-
def tril(A: _SpArray[_SCT], k: int = 0, *, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
35+
def tril(A: sparray[_SCT], k: int = 0, *, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
4536
@overload # sparray -> csc_array (positional)
46-
def tril(A: _SpArray[_SCT], k: int, format: Literal["csc"]) -> csc_array[_SCT]: ...
37+
def tril(A: sparray[_SCT], k: int, format: Literal["csc"]) -> csc_array[_SCT]: ...
4738
@overload # sparray -> csc_array (keyword)
48-
def tril(A: _SpArray[_SCT], k: int = 0, *, format: Literal["csc"]) -> csc_array[_SCT]: ...
39+
def tril(A: sparray[_SCT], k: int = 0, *, format: Literal["csc"]) -> csc_array[_SCT]: ...
4940
@overload # sparray -> csr_array (positional)
50-
def tril(A: _SpArray[_SCT], k: int, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
41+
def tril(A: sparray[_SCT], k: int, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
5142
@overload # sparray -> csr_array (keyword)
52-
def tril(A: _SpArray[_SCT], k: int = 0, *, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
43+
def tril(A: sparray[_SCT], k: int = 0, *, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
5344
@overload # sparray -> dia_array (positional)
54-
def tril(A: _SpArray[_SCT], k: int, format: Literal["dia"]) -> dia_array[_SCT]: ...
45+
def tril(A: sparray[_SCT], k: int, format: Literal["dia"]) -> dia_array[_SCT]: ...
5546
@overload # sparray -> dia_array (keyword)
56-
def tril(A: _SpArray[_SCT], k: int = 0, *, format: Literal["dia"]) -> dia_array[_SCT]: ...
47+
def tril(A: sparray[_SCT], k: int = 0, *, format: Literal["dia"]) -> dia_array[_SCT]: ...
5748
@overload # sparray -> dok_array (positional)
58-
def tril(A: _SpArray[_SCT], k: int, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
49+
def tril(A: sparray[_SCT], k: int, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
5950
@overload # sparray -> dok_array (keyword)
60-
def tril(A: _SpArray[_SCT], k: int = 0, *, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
51+
def tril(A: sparray[_SCT], k: int = 0, *, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
6152
@overload # sparray -> lil_array (positional)
62-
def tril(A: _SpArray[_SCT], k: int, format: Literal["lil"]) -> lil_array[_SCT]: ...
53+
def tril(A: sparray[_SCT], k: int, format: Literal["lil"]) -> lil_array[_SCT]: ...
6354
@overload # sparray -> lil_array (keyword)
64-
def tril(A: _SpArray[_SCT], k: int = 0, *, format: Literal["lil"]) -> lil_array[_SCT]: ...
55+
def tril(A: sparray[_SCT], k: int = 0, *, format: Literal["lil"]) -> lil_array[_SCT]: ...
6556
@overload # spmatrix | array-like -> coo_matrix (default)
6657
def tril(A: spmatrix[_SCT] | _ToDense[_SCT], k: int = 0, format: Literal["coo"] | None = None) -> coo_matrix[_SCT]: ...
6758
@overload # spmatrix | array-like -> bsr_matrix (positional)
@@ -93,31 +84,31 @@ def tril(A: _spbase | onp.ToComplexND, k: int = 0, *, format: SPFormat | None =
9384

9485
#
9586
@overload # sparray -> coo_array (default)
96-
def triu(A: _SpArray[_SCT], k: int = 0, format: Literal["coo"] | None = None) -> coo_array[_SCT, tuple[int, int]]: ...
87+
def triu(A: sparray[_SCT], k: int = 0, format: Literal["coo"] | None = None) -> coo_array[_SCT, tuple[int, int]]: ...
9788
@overload # sparray -> bsr_array (positional)
98-
def triu(A: _SpArray[_SCT], k: int, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
89+
def triu(A: sparray[_SCT], k: int, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
9990
@overload # sparray -> bsr_array (keyword)
100-
def triu(A: _SpArray[_SCT], k: int = 0, *, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
91+
def triu(A: sparray[_SCT], k: int = 0, *, format: Literal["bsr"]) -> bsr_array[_SCT]: ...
10192
@overload # sparray -> csc_array (positional)
102-
def triu(A: _SpArray[_SCT], k: int, format: Literal["csc"]) -> csc_array[_SCT]: ...
93+
def triu(A: sparray[_SCT], k: int, format: Literal["csc"]) -> csc_array[_SCT]: ...
10394
@overload # sparray -> csc_array (keyword)
104-
def triu(A: _SpArray[_SCT], k: int = 0, *, format: Literal["csc"]) -> csc_array[_SCT]: ...
95+
def triu(A: sparray[_SCT], k: int = 0, *, format: Literal["csc"]) -> csc_array[_SCT]: ...
10596
@overload # sparray -> csr_array (positional)
106-
def triu(A: _SpArray[_SCT], k: int, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
97+
def triu(A: sparray[_SCT], k: int, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
10798
@overload # sparray -> csr_array (keyword)
108-
def triu(A: _SpArray[_SCT], k: int = 0, *, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
99+
def triu(A: sparray[_SCT], k: int = 0, *, format: Literal["csr"]) -> csr_array[_SCT, tuple[int, int]]: ...
109100
@overload # sparray -> dia_array (positional)
110-
def triu(A: _SpArray[_SCT], k: int, format: Literal["dia"]) -> dia_array[_SCT]: ...
101+
def triu(A: sparray[_SCT], k: int, format: Literal["dia"]) -> dia_array[_SCT]: ...
111102
@overload # sparray -> dia_array (keyword)
112-
def triu(A: _SpArray[_SCT], k: int = 0, *, format: Literal["dia"]) -> dia_array[_SCT]: ...
103+
def triu(A: sparray[_SCT], k: int = 0, *, format: Literal["dia"]) -> dia_array[_SCT]: ...
113104
@overload # sparray -> dok_array (positional)
114-
def triu(A: _SpArray[_SCT], k: int, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
105+
def triu(A: sparray[_SCT], k: int, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
115106
@overload # sparray -> dok_array (keyword)
116-
def triu(A: _SpArray[_SCT], k: int = 0, *, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
107+
def triu(A: sparray[_SCT], k: int = 0, *, format: Literal["dok"]) -> dok_array[_SCT, tuple[int, int]]: ...
117108
@overload # sparray -> lil_array (positional)
118-
def triu(A: _SpArray[_SCT], k: int, format: Literal["lil"]) -> lil_array[_SCT]: ...
109+
def triu(A: sparray[_SCT], k: int, format: Literal["lil"]) -> lil_array[_SCT]: ...
119110
@overload # sparray -> lil_array (keyword)
120-
def triu(A: _SpArray[_SCT], k: int = 0, *, format: Literal["lil"]) -> lil_array[_SCT]: ...
111+
def triu(A: sparray[_SCT], k: int = 0, *, format: Literal["lil"]) -> lil_array[_SCT]: ...
121112
@overload # spmatrix | array-like -> coo_matrix (default)
122113
def triu(A: spmatrix[_SCT] | _ToDense[_SCT], k: int = 0, format: Literal["coo"] | None = None) -> coo_matrix[_SCT]: ...
123114
@overload # spmatrix | array-like -> bsr_matrix (positional)

scipy-stubs/sparse/_lil.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class _lil_base(_spbase[_SCT, tuple[int, int]], IndexMixin[_SCT, tuple[int, int]
155155
def getrowview(self, /, i: int) -> Self: ...
156156
def getrow(self, /, i: onp.ToJustInt) -> csr_array[_SCT, tuple[int, int]] | csr_matrix[_SCT]: ...
157157

158-
class lil_array(_lil_base[_SCT], sparray, Generic[_SCT]):
158+
class lil_array(_lil_base[_SCT], sparray[_SCT, tuple[int, int]], Generic[_SCT]):
159159
@override
160160
def getrow(self, /, i: onp.ToJustInt) -> csr_array[_SCT, tuple[int, int]]: ...
161161

0 commit comments

Comments
 (0)