1
+ import sqlite3
1
2
import sys
2
- from _typeshed import ReadableBuffer, Self, StrOrBytesPath
3
+ from _typeshed import ReadableBuffer, Self, StrOrBytesPath, SupportsLenAndGetItem
4
+ from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
3
5
from datetime import date, datetime, time
4
6
from types import TracebackType
5
- from typing import Any, Callable, Generator, Iterable, Iterator , Protocol, TypeVar, overload
6
- from typing_extensions import Literal, final
7
+ from typing import Any, Generic , Protocol, TypeVar, overload
8
+ from typing_extensions import Literal, SupportsIndex, TypeAlias, final
7
9
8
10
_T = TypeVar("_T")
9
- _SqliteData = str | bytes | int | float | None
11
+ _T_co = TypeVar("_T_co", covariant=True)
12
+ _CursorT = TypeVar("_CursorT", bound=Cursor)
13
+ _SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
14
+ # Data that is passed through adapters can be of any type accepted by an adapter.
15
+ _AdaptedInputData: TypeAlias = _SqliteData | Any
16
+ # The Mapping must really be a dict, but making it invariant is too annoying.
17
+ _Parameters: TypeAlias = SupportsLenAndGetItem[_AdaptedInputData] | Mapping[str, _AdaptedInputData]
18
+ _SqliteOutputData: TypeAlias = str | bytes | int | float | None
19
+ _Adapter: TypeAlias = Callable[[_T], _SqliteData]
20
+ _Converter: TypeAlias = Callable[[bytes], Any]
10
21
11
22
paramstyle: str
12
23
threadsafety: int
@@ -81,43 +92,39 @@ if sys.version_info >= (3, 7):
81
92
SQLITE_SELECT: int
82
93
SQLITE_TRANSACTION: int
83
94
SQLITE_UPDATE: int
84
- adapters: Any
85
- converters: Any
95
+ adapters: dict[tuple[type[ Any], type[Any]], _Adapter[Any]]
96
+ converters: dict[str, _Converter]
86
97
sqlite_version: str
87
98
version: str
88
99
89
- # TODO: adapt needs to get probed
90
- def adapt(obj, protocol, alternate): ...
100
+ # Can take or return anything depending on what's in the registry.
101
+ @overload
102
+ def adapt(__obj: Any, __proto: Any) -> Any: ...
103
+ @overload
104
+ def adapt(__obj: Any, __proto: Any, __alt: _T) -> Any | _T: ...
91
105
def complete_statement(statement: str) -> bool: ...
92
106
93
107
if sys.version_info >= (3, 7):
94
- def connect(
95
- database: StrOrBytesPath,
96
- timeout: float = ...,
97
- detect_types: int = ...,
98
- isolation_level: str | None = ...,
99
- check_same_thread: bool = ...,
100
- factory: type[Connection] | None = ...,
101
- cached_statements: int = ...,
102
- uri: bool = ...,
103
- ) -> Connection: ...
104
-
108
+ _DatabaseArg: TypeAlias = StrOrBytesPath
105
109
else:
106
- def connect(
107
- database: bytes | str,
108
- timeout: float = ...,
109
- detect_types: int = ...,
110
- isolation_level: str | None = ...,
111
- check_same_thread: bool = ...,
112
- factory: type[Connection] | None = ...,
113
- cached_statements: int = ...,
114
- uri: bool = ...,
115
- ) -> Connection: ...
110
+ _DatabaseArg: TypeAlias = bytes | str
116
111
112
+ def connect(
113
+ database: _DatabaseArg,
114
+ timeout: float = ...,
115
+ detect_types: int = ...,
116
+ isolation_level: str | None = ...,
117
+ check_same_thread: bool = ...,
118
+ factory: type[Connection] | None = ...,
119
+ cached_statements: int = ...,
120
+ uri: bool = ...,
121
+ ) -> Connection: ...
117
122
def enable_callback_tracebacks(__enable: bool) -> None: ...
123
+
124
+ # takes a pos-or-keyword argument because there is a C wrapper
118
125
def enable_shared_cache(enable: int) -> None: ...
119
- def register_adapter(__type: type[_T], __caster: Callable[[_T], int | float | str | bytes ]) -> None: ...
120
- def register_converter(__name: str, __converter: Callable[[bytes], Any] ) -> None: ...
126
+ def register_adapter(__type: type[_T], __caster: _Adapter[_T ]) -> None: ...
127
+ def register_converter(__name: str, __converter: _Converter ) -> None: ...
121
128
122
129
if sys.version_info < (3, 8):
123
130
class Cache:
@@ -126,7 +133,7 @@ if sys.version_info < (3, 8):
126
133
def get(self, *args, **kwargs) -> None: ...
127
134
128
135
class _AggregateProtocol(Protocol):
129
- def step(self, value : int) -> object: ...
136
+ def step(self, __value : int) -> object: ...
130
137
def finalize(self) -> int: ...
131
138
132
139
class _SingleParamWindowAggregateClass(Protocol):
@@ -148,22 +155,44 @@ class _WindowAggregateClass(Protocol):
148
155
def finalize(self) -> _SqliteData: ...
149
156
150
157
class Connection:
151
- DataError: Any
152
- DatabaseError: Any
153
- Error: Any
154
- IntegrityError: Any
155
- InterfaceError: Any
156
- InternalError: Any
157
- NotSupportedError: Any
158
- OperationalError: Any
159
- ProgrammingError: Any
160
- Warning: Any
161
- in_transaction: Any
162
- isolation_level: Any
158
+ @property
159
+ def DataError(self) -> type[sqlite3.DataError]: ...
160
+ @property
161
+ def DatabaseError(self) -> type[sqlite3.DatabaseError]: ...
162
+ @property
163
+ def Error(self) -> type[sqlite3.Error]: ...
164
+ @property
165
+ def IntegrityError(self) -> type[sqlite3.IntegrityError]: ...
166
+ @property
167
+ def InterfaceError(self) -> type[sqlite3.InterfaceError]: ...
168
+ @property
169
+ def InternalError(self) -> type[sqlite3.InternalError]: ...
170
+ @property
171
+ def NotSupportedError(self) -> type[sqlite3.NotSupportedError]: ...
172
+ @property
173
+ def OperationalError(self) -> type[sqlite3.OperationalError]: ...
174
+ @property
175
+ def ProgrammingError(self) -> type[sqlite3.ProgrammingError]: ...
176
+ @property
177
+ def Warning(self) -> type[sqlite3.Warning]: ...
178
+ @property
179
+ def in_transaction(self) -> bool: ...
180
+ isolation_level: str | None # one of '', 'DEFERRED', 'IMMEDIATE' or 'EXCLUSIVE'
181
+ @property
182
+ def total_changes(self) -> int: ...
163
183
row_factory: Any
164
184
text_factory: Any
165
- total_changes: Any
166
- def __init__(self, *args: Any, **kwargs: Any) -> None: ...
185
+ def __init__(
186
+ self,
187
+ database: _DatabaseArg,
188
+ timeout: float = ...,
189
+ detect_types: int = ...,
190
+ isolation_level: str | None = ...,
191
+ check_same_thread: bool = ...,
192
+ factory: type[Connection] | None = ...,
193
+ cached_statements: int = ...,
194
+ uri: bool = ...,
195
+ ) -> None: ...
167
196
def close(self) -> None: ...
168
197
if sys.version_info >= (3, 11):
169
198
def blobopen(self, __table: str, __column: str, __row: int, *, readonly: bool = ..., name: str = ...) -> Blob: ...
@@ -187,17 +216,21 @@ class Connection:
187
216
self, __name: str, __num_params: int, __aggregate_class: Callable[[], _WindowAggregateClass] | None
188
217
) -> None: ...
189
218
190
- def create_collation(self, __name: str, __callback: Any ) -> None: ...
219
+ def create_collation(self, __name: str, __callback: Callable[[str, str], int | SupportsIndex] | None ) -> None: ...
191
220
if sys.version_info >= (3, 8):
192
- def create_function(self, name: str, narg: int, func: Any, *, deterministic: bool = ...) -> None: ...
221
+ def create_function(
222
+ self, name: str, narg: int, func: Callable[..., _SqliteData], *, deterministic: bool = ...
223
+ ) -> None: ...
193
224
else:
194
- def create_function(self, name: str, num_params: int, func: Any ) -> None: ...
225
+ def create_function(self, name: str, num_params: int, func: Callable[..., _SqliteData] ) -> None: ...
195
226
196
- def cursor(self, cursorClass: type | None = ...) -> Cursor: ...
197
- def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Cursor: ...
198
- # TODO: please check in executemany() if seq_of_parameters type is possible like this
199
- def executemany(self, __sql: str, __parameters: Iterable[Iterable[Any]]) -> Cursor: ...
200
- def executescript(self, __sql_script: bytes | str) -> Cursor: ...
227
+ @overload
228
+ def cursor(self, cursorClass: None = ...) -> Cursor: ...
229
+ @overload
230
+ def cursor(self, cursorClass: Callable[[], _CursorT]) -> _CursorT: ...
231
+ def execute(self, sql: str, parameters: _Parameters = ...) -> Cursor: ...
232
+ def executemany(self, __sql: str, __parameters: Iterable[_Parameters]) -> Cursor: ...
233
+ def executescript(self, __sql_script: str) -> Cursor: ...
201
234
def interrupt(self) -> None: ...
202
235
def iterdump(self) -> Generator[str, None, None]: ...
203
236
def rollback(self) -> None: ...
@@ -208,8 +241,8 @@ class Connection:
208
241
def set_trace_callback(self, trace_callback: Callable[[str], object] | None) -> None: ...
209
242
# enable_load_extension and load_extension is not available on python distributions compiled
210
243
# without sqlite3 loadable extension support. see footnotes https://docs.python.org/3/library/sqlite3.html#f1
211
- def enable_load_extension(self, enabled : bool) -> None: ...
212
- def load_extension(self, path : str) -> None: ...
244
+ def enable_load_extension(self, __enabled : bool) -> None: ...
245
+ def load_extension(self, __name : str) -> None: ...
213
246
if sys.version_info >= (3, 7):
214
247
def backup(
215
248
self,
@@ -226,29 +259,32 @@ class Connection:
226
259
def serialize(self, *, name: str = ...) -> bytes: ...
227
260
def deserialize(self, __data: ReadableBuffer, *, name: str = ...) -> None: ...
228
261
229
- def __call__(self, *args: Any, **kwargs: Any ) -> Any : ...
262
+ def __call__(self, __sql: str ) -> _Statement : ...
230
263
def __enter__(self: Self) -> Self: ...
231
264
def __exit__(
232
265
self, __type: type[BaseException] | None, __value: BaseException | None, __traceback: TracebackType | None
233
266
) -> Literal[False]: ...
234
267
235
268
class Cursor(Iterator[Any]):
236
- arraysize: Any
237
- connection: Any
238
- description: Any
239
- lastrowid: Any
240
- row_factory: Any
241
- rowcount: int
242
- # TODO: Cursor class accepts exactly 1 argument
243
- # required type is sqlite3.Connection (which is imported as _Connection)
244
- # however, the name of the __init__ variable is unknown
245
- def __init__(self, *args: Any, **kwargs: Any) -> None: ...
269
+ arraysize: int
270
+ @property
271
+ def connection(self) -> Connection: ...
272
+ @property
273
+ def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...] | None: ...
274
+ @property
275
+ def lastrowid(self) -> int | None: ...
276
+ row_factory: Callable[[Cursor, Row[Any]], object] | None
277
+ @property
278
+ def rowcount(self) -> int: ...
279
+ def __init__(self, __cursor: Connection) -> None: ...
246
280
def close(self) -> None: ...
247
- def execute(self, __sql: str, __parameters: Iterable[Any] = ...) -> Cursor : ...
248
- def executemany(self, __sql: str, __seq_of_parameters: Iterable[Iterable[Any]] ) -> Cursor : ...
249
- def executescript(self, __sql_script: bytes | str) -> Cursor: ...
281
+ def execute(self: Self , __sql: str, __parameters: _Parameters = ...) -> Self : ...
282
+ def executemany(self: Self , __sql: str, __seq_of_parameters: Iterable[_Parameters] ) -> Self : ...
283
+ def executescript(self, __sql_script: str) -> Cursor: ...
250
284
def fetchall(self) -> list[Any]: ...
251
285
def fetchmany(self, size: int | None = ...) -> list[Any]: ...
286
+ # Returns either a row (as created by the row_factory) or None, but
287
+ # putting None in the return annotation causes annoying false positives.
252
288
def fetchone(self) -> Any: ...
253
289
def setinputsizes(self, __sizes: object) -> None: ... # does nothing
254
290
def setoutputsize(self, __size: object, __column: object = ...) -> None: ... # does nothing
@@ -273,28 +309,37 @@ OptimizedUnicode = str
273
309
274
310
@final
275
311
class PrepareProtocol:
276
- def __init__(self, *args: Any , **kwargs: Any ) -> None: ...
312
+ def __init__(self, *args: object , **kwargs: object ) -> None: ...
277
313
278
314
class ProgrammingError(DatabaseError): ...
279
315
280
- class Row:
281
- def __init__(self, *args: Any, **kwargs: Any) -> None: ...
282
- def keys(self): ...
283
- def __eq__(self, __other): ...
284
- def __ge__(self, __other): ...
285
- def __getitem__(self, __index): ...
286
- def __gt__(self, __other): ...
287
- def __hash__(self): ...
288
- def __iter__(self): ...
289
- def __le__(self, __other): ...
290
- def __len__(self): ...
291
- def __lt__(self, __other): ...
292
- def __ne__(self, __other): ...
316
+ class Row(Generic[_T_co]):
317
+ def __init__(self, __cursor: Cursor, __data: tuple[_T_co, ...]) -> None: ...
318
+ def keys(self) -> list[str]: ...
319
+ @overload
320
+ def __getitem__(self, __index: int | str) -> _T_co: ...
321
+ @overload
322
+ def __getitem__(self, __index: slice) -> tuple[_T_co, ...]: ...
323
+ def __hash__(self) -> int: ...
324
+ def __iter__(self) -> Iterator[_T_co]: ...
325
+ def __len__(self) -> int: ...
326
+ # These return NotImplemented for anything that is not a Row.
327
+ def __eq__(self, __other: object) -> bool: ...
328
+ def __ge__(self, __other: object) -> bool: ...
329
+ def __gt__(self, __other: object) -> bool: ...
330
+ def __le__(self, __other: object) -> bool: ...
331
+ def __lt__(self, __other: object) -> bool: ...
332
+ def __ne__(self, __other: object) -> bool: ...
293
333
294
- if sys.version_info < (3, 8):
334
+ if sys.version_info >= (3, 8):
335
+ @final
336
+ class _Statement: ...
337
+
338
+ else:
295
339
@final
296
340
class Statement:
297
341
def __init__(self, *args, **kwargs): ...
342
+ _Statement: TypeAlias = Statement
298
343
299
344
class Warning(Exception): ...
300
345
0 commit comments