diff --git a/Makefile b/Makefile index 59d08bac..9af372f7 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ tests_websockets: pytest tests --websockets-only check: - isort --recursive $(SRC_PYTHON) + isort $(SRC_PYTHON) black $(SRC_PYTHON) flake8 $(SRC_PYTHON) mypy $(SRC_PYTHON) diff --git a/docs/code_examples/aiohttp_async_dsl.py b/docs/code_examples/aiohttp_async_dsl.py index 958ea490..2c4804db 100644 --- a/docs/code_examples/aiohttp_async_dsl.py +++ b/docs/code_examples/aiohttp_async_dsl.py @@ -17,6 +17,8 @@ async def main(): # GQL will fetch the schema just after the establishment of the first session async with client as session: + assert client.schema is not None + # Instantiate the root of the DSL Schema as ds ds = DSLSchema(client.schema) diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py index 9a5e94e5..6c0b86d0 100644 --- a/docs/code_examples/console_async.py +++ b/docs/code_examples/console_async.py @@ -1,8 +1,11 @@ import asyncio import logging +from typing import Optional from aioconsole import ainput + from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.aiohttp import AIOHTTPTransport logging.basicConfig(level=logging.INFO) @@ -21,7 +24,7 @@ def __init__(self): self._client = Client( transport=AIOHTTPTransport(url="https://countries.trevorblades.com/") ) - self._session = None + self._session: Optional[AsyncClientSession] = None self.get_continent_name_query = gql(GET_CONTINENT_NAME) @@ -34,11 +37,13 @@ async def close(self): async def get_continent_name(self, code): params = {"code": code} + assert self._session is not None + answer = await self._session.execute( self.get_continent_name_query, variable_values=params ) - return answer.get("continent").get("name") + return answer.get("continent").get("name") # type: ignore async def main(): diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 3bedd187..f4a5c14b 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -12,6 +12,7 @@ from fastapi.responses import HTMLResponse from gql import Client, gql +from gql.client import ReconnectingAsyncClientSession from gql.transport.aiohttp import AIOHTTPTransport logging.basicConfig(level=logging.DEBUG) @@ -91,6 +92,7 @@ async def get_continent(continent_code): raise HTTPException(status_code=404, detail="Continent not found") try: + assert isinstance(client.session, ReconnectingAsyncClientSession) result = await client.session.execute( query, variable_values={"code": continent_code} ) diff --git a/docs/conf.py b/docs/conf.py index 94daf942..8289ef4b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,11 +83,11 @@ intersphinx_mapping = { 'aiohttp': ('https://docs.aiohttp.org/en/stable/', None), 'graphql': ('https://graphql-core-3.readthedocs.io/en/latest/', None), - 'multidict': ('https://multidict.readthedocs.io/en/stable/', None), + 'multidict': ('https://multidict.aio-libs.org/en/stable/', None), 'python': ('https://docs.python.org/3/', None), 'requests': ('https://requests.readthedocs.io/en/latest/', None), 'websockets': ('https://websockets.readthedocs.io/en/11.0.3/', None), - 'yarl': ('https://yarl.readthedocs.io/en/stable/', None), + 'yarl': ('https://yarl.aio-libs.org/en/stable/', None), } nitpick_ignore = [ @@ -100,6 +100,8 @@ ('py:class', 'asyncio.locks.Event'), # aiohttp: should be fixed + # See issue: https://github.com/aio-libs/aiohttp/issues/10468 + ('py:class', 'aiohttp.client.ClientSession'), ('py:class', 'aiohttp.client_reqrep.Fingerprint'), ('py:class', 'aiohttp.helpers.BasicAuth'), diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index b7c13c7c..035f196f 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -24,11 +24,15 @@ Sub-Packages transport_aiohttp_websockets transport_appsync_auth transport_appsync_websockets + transport_common_base + transport_common_adapters_connection + transport_common_adapters_aiohttp + transport_common_adapters_websockets transport_exceptions transport_phoenix_channel_websockets transport_requests transport_httpx transport_websockets - transport_websockets_base + transport_websockets_protocol dsl utilities diff --git a/docs/modules/transport_common_adapters_aiohttp.rst b/docs/modules/transport_common_adapters_aiohttp.rst new file mode 100644 index 00000000..537c8673 --- /dev/null +++ b/docs/modules/transport_common_adapters_aiohttp.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.aiohttp +===================================== + +.. currentmodule:: gql.transport.common.adapters.aiohttp + +.. automodule:: gql.transport.common.adapters.aiohttp + :member-order: bysource diff --git a/docs/modules/transport_common_adapters_connection.rst b/docs/modules/transport_common_adapters_connection.rst new file mode 100644 index 00000000..ffa1a1b3 --- /dev/null +++ b/docs/modules/transport_common_adapters_connection.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.connection +======================================== + +.. currentmodule:: gql.transport.common.adapters.connection + +.. automodule:: gql.transport.common.adapters.connection + :member-order: bysource diff --git a/docs/modules/transport_common_adapters_websockets.rst b/docs/modules/transport_common_adapters_websockets.rst new file mode 100644 index 00000000..4005694c --- /dev/null +++ b/docs/modules/transport_common_adapters_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.websockets +======================================== + +.. currentmodule:: gql.transport.common.adapters.websockets + +.. automodule:: gql.transport.common.adapters.websockets + :member-order: bysource diff --git a/docs/modules/transport_common_base.rst b/docs/modules/transport_common_base.rst new file mode 100644 index 00000000..4a7ec15a --- /dev/null +++ b/docs/modules/transport_common_base.rst @@ -0,0 +1,7 @@ +gql.transport.common.base +========================= + +.. currentmodule:: gql.transport.common.base + +.. automodule:: gql.transport.common.base + :member-order: bysource diff --git a/docs/modules/transport_websockets_base.rst b/docs/modules/transport_websockets_base.rst deleted file mode 100644 index 548351eb..00000000 --- a/docs/modules/transport_websockets_base.rst +++ /dev/null @@ -1,7 +0,0 @@ -gql.transport.websockets_base -============================= - -.. currentmodule:: gql.transport.websockets_base - -.. automodule:: gql.transport.websockets_base - :member-order: bysource diff --git a/docs/modules/transport_websockets_protocol.rst b/docs/modules/transport_websockets_protocol.rst new file mode 100644 index 00000000..b835abee --- /dev/null +++ b/docs/modules/transport_websockets_protocol.rst @@ -0,0 +1,7 @@ +gql.transport.websockets_protocol +================================= + +.. currentmodule:: gql.transport.websockets_protocol + +.. automodule:: gql.transport.websockets_protocol + :member-order: bysource diff --git a/gql/cli.py b/gql/cli.py index 91c67873..9ae92e83 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -391,9 +391,10 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt) else: - from gql.transport.appsync_auth import AppSyncIAMAuthentication from botocore.exceptions import NoRegionError + from gql.transport.appsync_auth import AppSyncIAMAuthentication + try: auth = AppSyncIAMAuthentication(host=url.host) except NoRegionError: diff --git a/gql/client.py b/gql/client.py index c52a00b2..faf3230a 100644 --- a/gql/client.py +++ b/gql/client.py @@ -131,7 +131,10 @@ def __init__( self.introspection: Optional[IntrospectionQuery] = introspection # GraphQL transport chosen - self.transport: Optional[Union[Transport, AsyncTransport]] = transport + assert ( + transport is not None + ), "You need to provide either a transport or a schema to the Client." + self.transport: Union[Transport, AsyncTransport] = transport # Flag to indicate that we need to fetch the schema from the transport # On async transports, we fetch the schema before executing the first query @@ -149,10 +152,10 @@ def __init__( self.batch_max = batch_max @property - def batching_enabled(self): + def batching_enabled(self) -> bool: return self.batch_interval != 0 - def validate(self, document: DocumentNode): + def validate(self, document: DocumentNode) -> None: """:meta private:""" assert ( self.schema @@ -162,7 +165,9 @@ def validate(self, document: DocumentNode): if validation_errors: raise validation_errors[0] - def _build_schema_from_introspection(self, execution_result: ExecutionResult): + def _build_schema_from_introspection( + self, execution_result: ExecutionResult + ) -> None: if execution_result.errors: raise TransportQueryError( ( @@ -189,9 +194,8 @@ def execute_sync( parse_result: Optional[bool] = ..., *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute_sync( @@ -203,9 +207,8 @@ def execute_sync( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute_sync( @@ -217,9 +220,8 @@ def execute_sync( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute_sync( self, @@ -229,7 +231,7 @@ def execute_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" with self as session: @@ -251,9 +253,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch_sync( @@ -263,9 +264,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch_sync( @@ -275,9 +275,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch_sync( self, @@ -286,7 +285,7 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """:meta private:""" with self as session: @@ -308,9 +307,8 @@ async def execute_async( parse_result: Optional[bool] = ..., *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload async def execute_async( @@ -322,9 +320,8 @@ async def execute_async( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute_async( @@ -336,9 +333,8 @@ async def execute_async( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover async def execute_async( self, @@ -348,7 +344,7 @@ async def execute_async( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" async with self as session: @@ -372,9 +368,8 @@ def execute( parse_result: Optional[bool] = ..., *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute( @@ -386,9 +381,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute( @@ -400,9 +394,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute( self, @@ -412,7 +405,7 @@ def execute( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """Execute the provided document AST against the remote server using the transport provided during init. @@ -487,9 +480,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -499,9 +491,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch( @@ -511,9 +502,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -522,7 +512,7 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """Execute multiple GraphQL requests in a batch against the remote server using the transport provided during init. @@ -568,9 +558,8 @@ def subscribe_async( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @overload def subscribe_async( @@ -582,9 +571,8 @@ def subscribe_async( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe_async( @@ -596,11 +584,10 @@ def subscribe_async( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover async def subscribe_async( self, @@ -610,7 +597,7 @@ async def subscribe_async( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ]: @@ -639,9 +626,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Generator[Dict[str, Any], None, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> Generator[Dict[str, Any], None, None]: ... # pragma: no cover @overload def subscribe( @@ -653,9 +639,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> Generator[ExecutionResult, None, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> Generator[ExecutionResult, None, None]: ... # pragma: no cover @overload def subscribe( @@ -667,11 +652,10 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover def subscribe( self, @@ -682,7 +666,7 @@ def subscribe( parse_result: Optional[bool] = None, *, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] ]: @@ -770,6 +754,8 @@ async def connect_async(self, reconnecting=False, **kwargs): self.transport, AsyncTransport ), "Only a transport of type AsyncTransport can be used asynchronously" + self.session: Union[AsyncClientSession, SyncClientSession] + if reconnecting: self.session = ReconnectingAsyncClientSession(client=self, **kwargs) await self.session.start_connecting_task() @@ -825,6 +811,8 @@ def connect_sync(self): if not hasattr(self, "session"): self.session = SyncClientSession(client=self) + assert isinstance(self.session, SyncClientSession) + self.session.connect() # Get schema from transport if needed @@ -846,6 +834,8 @@ def close_sync(self): If batching is enabled, this will block until the remaining queries in the batching queue have been processed. """ + assert isinstance(self.session, SyncClientSession) + self.session.close() def __enter__(self): @@ -873,7 +863,7 @@ def _execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Execute the provided document AST synchronously using the sync transport, returning an ExecutionResult object. @@ -944,9 +934,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute( @@ -958,9 +947,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute( @@ -972,9 +960,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute( self, @@ -984,7 +971,7 @@ def execute( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """Execute the provided document AST synchronously using the sync transport. @@ -1040,7 +1027,7 @@ def _execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, validate_document: Optional[bool] = True, - **kwargs, + **kwargs: Any, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch, using the sync transport, returning a list of ExecutionResult objects. @@ -1067,9 +1054,11 @@ def _execute_batch( serialize_variables is None and self.client.serialize_variables ): requests = [ - req.serialize_variable_values(self.client.schema) - if req.variable_values is not None - else req + ( + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req + ) for req in requests ] @@ -1096,9 +1085,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -1108,9 +1096,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch( @@ -1120,9 +1107,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -1131,7 +1117,7 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """Execute multiple GraphQL requests in a batch, using the sync transport. This method sends the requests to the server all at once. @@ -1312,7 +1298,7 @@ async def _subscribe( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport, @@ -1349,13 +1335,13 @@ async def _subscribe( ) # Subscribe to the transport - inner_generator: AsyncGenerator[ - ExecutionResult, None - ] = self.transport.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, + inner_generator: AsyncGenerator[ExecutionResult, None] = ( + self.transport.subscribe( + document, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) ) # Keep a reference to the inner generator @@ -1390,9 +1376,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @overload def subscribe( @@ -1404,9 +1389,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe( @@ -1418,11 +1402,10 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover async def subscribe( self, @@ -1432,7 +1415,7 @@ async def subscribe( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ]: @@ -1491,7 +1474,7 @@ async def _execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Coroutine to execute the provided document AST asynchronously using the async transport, returning an ExecutionResult object. @@ -1557,9 +1540,8 @@ async def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload async def execute( @@ -1571,9 +1553,8 @@ async def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute( @@ -1585,9 +1566,8 @@ async def execute( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover async def execute( self, @@ -1597,7 +1577,7 @@ async def execute( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """Coroutine to execute the provided document AST asynchronously using the async transport. @@ -1775,7 +1755,7 @@ async def _execute_once( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent method _execute but requesting a reconnection if we receive a TransportClosed exception. @@ -1803,7 +1783,7 @@ async def _execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent, but with optional retries and requesting a reconnection if we receive a TransportClosed exception. @@ -1825,7 +1805,7 @@ async def _subscribe( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Same Async generator as parent method _subscribe but requesting a reconnection if we receive a TransportClosed exception. diff --git a/gql/dsl.py b/gql/dsl.py index be2b5a7e..e5b5131e 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -2,6 +2,7 @@ .. image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa :alt: UML diagram """ + import logging import re from abc import ABC, abstractmethod @@ -338,7 +339,7 @@ def select( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", - ): + ) -> Any: r"""Select the fields which should be added. :param \*fields: fields or fragments @@ -595,9 +596,11 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: VariableDefinitionNode( type=var.ast_variable_type, variable=var.ast_variable_name, - default_value=None - if var.default_value is None - else ast_from_value(var.default_value, var.type), + default_value=( + None + if var.default_value is None + else ast_from_value(var.default_value, var.type) + ), directives=(), ) for var in self.variables.values() @@ -836,10 +839,10 @@ def name(self): """:meta private:""" return self.ast_field.name.value - def __call__(self, **kwargs) -> "DSLField": + def __call__(self, **kwargs: Any) -> "DSLField": return self.args(**kwargs) - def args(self, **kwargs) -> "DSLField": + def args(self, **kwargs: Any) -> "DSLField": r"""Set the arguments of a field The arguments are parsed to be stored in the AST of this field. diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index b581e311..76b46c35 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -3,7 +3,17 @@ import json import logging from ssl import SSLContext -from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + NoReturn, + Optional, + Tuple, + Type, + Union, +) import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -102,9 +112,9 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": None - if isinstance(self.auth, AppSyncAuthentication) - else self.auth, + "auth": ( + None if isinstance(self.auth, AppSyncAuthentication) else self.auth + ), "json_serialize": self.json_serialize, } @@ -262,7 +272,9 @@ async def execute( # Saving latest response headers in the transport self.response_headers = resp.headers - async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): + async def raise_response_error( + resp: aiohttp.ClientResponse, reason: str + ) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index f35cefe5..a6a7d180 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -29,7 +29,7 @@ class AppSyncWebsocketsTransport(SubscriptionTransportBase): on a websocket connection. """ - auth: Optional[AppSyncAuthentication] + auth: AppSyncAuthentication def __init__( self, @@ -72,7 +72,7 @@ def __init__( # May raise NoRegionError or NoCredentialsError or ImportError auth = AppSyncIAMAuthentication(host=host, session=session) - self.auth = auth + self.auth: AppSyncAuthentication = auth self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.init_payload: Dict[str, Any] = {} diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py index f2dff699..736f2a3e 100644 --- a/gql/transport/common/adapters/aiohttp.py +++ b/gql/transport/common/adapters/aiohttp.py @@ -50,9 +50,9 @@ def __init__( certificate validation. :param session: Optional aiohttp opened session. :param client_session_args: Dict of extra args passed to - `aiohttp.ClientSession`_ + :class:`aiohttp.ClientSession` :param connect_args: Dict of extra args passed to - `aiohttp.ClientSession.ws_connect`_ + :meth:`aiohttp.ClientSession.ws_connect` :param float heartbeat: Send low level `ping` message every `heartbeat` seconds and wait `pong` response, close diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 770a8b34..a3d025c0 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -95,29 +95,29 @@ async def _initialize(self): """ pass # pragma: no cover - async def _stop_listener(self, query_id: int): + async def _stop_listener(self, query_id: int) -> None: """Hook to stop to listen to a specific query. Will send a stop message in some subclasses. """ pass # pragma: no cover - async def _after_connect(self): + async def _after_connect(self) -> None: """Hook to add custom code for subclasses after the connection has been established. """ pass # pragma: no cover - async def _after_initialize(self): + async def _after_initialize(self) -> None: """Hook to add custom code for subclasses after the initialization has been done. """ pass # pragma: no cover - async def _close_hook(self): + async def _close_hook(self) -> None: """Hook to add custom code for subclasses for the connection close""" pass # pragma: no cover - async def _connection_terminate(self): + async def _connection_terminate(self) -> None: """Hook to add custom code for subclasses after the initialization has been done. """ @@ -430,7 +430,7 @@ async def connect(self) -> None: log.debug("connect: done") - def _remove_listener(self, query_id) -> None: + def _remove_listener(self, query_id: int) -> None: """After exiting from a subscription, remove the listener and signal an event if this was the last listener for the client. """ diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 811601b8..4c5d33d0 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -7,6 +7,7 @@ Callable, Dict, List, + NoReturn, Optional, Tuple, Type, @@ -39,7 +40,7 @@ def __init__( url: Union[str, httpx.URL], json_serialize: Callable = json.dumps, json_deserialize: Callable = json.loads, - **kwargs, + **kwargs: Any, ): """Initialize the transport with the given httpx parameters. @@ -93,7 +94,9 @@ def _prepare_request( return post_args - def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: + def _prepare_file_uploads( + self, variable_values: Dict[str, Any], payload: Dict[str, Any] + ) -> Dict[str, Any]: # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( @@ -163,7 +166,7 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: extensions=result.get("extensions"), ) - def _raise_response_error(self, response: httpx.Response, reason: str): + def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 04ed4ff1..19760ad6 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,6 +1,6 @@ import asyncio from inspect import isawaitable -from typing import AsyncGenerator, Awaitable, cast +from typing import Any, AsyncGenerator, Awaitable, cast from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe @@ -31,8 +31,8 @@ async def close(self): async def execute( self, document: DocumentNode, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> ExecutionResult: """Execute the provided document AST for on a local GraphQL Schema.""" @@ -58,8 +58,8 @@ async def _await_if_necessary(obj): async def subscribe( self, document: DocumentNode, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Send a subscription and receive the results using an async generator diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 3885fcac..8a975b73 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -42,7 +42,7 @@ def __init__( channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, ack_timeout: Optional[Union[int, float]] = 10, - **kwargs, + **kwargs: Any, ) -> None: """Initialize the transport with the given parameters. @@ -244,7 +244,7 @@ def _required_value(d: Any, key: str, label: str) -> Any: return value def _required_subscription_id( - d: Any, label: str, must_exist: bool = False, must_not_exist=False + d: Any, label: str, must_exist: bool = False, must_not_exist: bool = False ) -> str: subscription_id = str(_required_value(d, "subscriptionId", label)) if must_exist and (subscription_id not in self.subscriptions): diff --git a/gql/transport/requests.py b/gql/transport/requests.py index bd370908..44f8a362 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,13 +1,25 @@ import io import json import logging -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + NoReturn, + Optional, + Tuple, + Type, + Union, +) import requests from graphql import DocumentNode, ExecutionResult, print_ast from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar +from requests.structures import CaseInsensitiveDict from requests_toolbelt.multipart.encoder import MultipartEncoder from gql.transport import Transport @@ -100,9 +112,9 @@ def __init__( self.json_deserialize: Callable = json_deserialize self.kwargs = kwargs - self.session = None + self.session: Optional[requests.Session] = None - self.response_headers = None + self.response_headers: Optional[CaseInsensitiveDict[str]] = None def connect(self): if self.session is None: @@ -159,7 +171,7 @@ def execute( # type: ignore if operation_name: payload["operationName"] = operation_name - post_args = { + post_args: Dict[str, Any] = { "headers": self.headers, "auth": self.auth, "cookies": self.cookies, @@ -219,7 +231,7 @@ def execute( # type: ignore if post_args["headers"] is None: post_args["headers"] = {} else: - post_args["headers"] = {**post_args["headers"]} + post_args["headers"] = dict(post_args["headers"]) post_args["headers"]["Content-Type"] = data.content_type @@ -247,7 +259,7 @@ def execute( # type: ignore ) self.response_headers = response.headers - def raise_response_error(resp: requests.Response, reason: str): + def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases @@ -255,7 +267,8 @@ def raise_response_error(resp: requests.Response, reason: str): # Raise a HTTPError if response status is 400 or higher resp.raise_for_status() except requests.HTTPError as e: - raise TransportServerError(str(e), e.response.status_code) from e + status_code = e.response.status_code if e.response is not None else None + raise TransportServerError(str(e), status_code) from e result_text = resp.text raise TransportProtocolError( diff --git a/gql/transport/transport.py b/gql/transport/transport.py index a5bd7100..49d0aa34 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,5 +1,5 @@ import abc -from typing import List +from typing import Any, List from graphql import DocumentNode, ExecutionResult @@ -8,7 +8,9 @@ class Transport(abc.ABC): @abc.abstractmethod - def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: + def execute( + self, document: DocumentNode, *args: Any, **kwargs: Any + ) -> ExecutionResult: """Execute GraphQL query. Execute the provided document AST for either a remote or local GraphQL Schema. @@ -23,8 +25,8 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: def execute_batch( self, reqs: List[GraphQLRequest], - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch. @@ -35,7 +37,7 @@ def execute_batch( """ raise NotImplementedError( "This Transport has not implemented the execute_batch method" - ) # pragma: no cover + ) def connect(self): """Establish a session with the transport.""" diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index 3348c576..61a4bb85 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -194,7 +194,7 @@ async def _send_complete_message(self, query_id: int) -> None: await self._send(complete_message) - async def _stop_listener(self, query_id: int): + async def _stop_listener(self, query_id: int) -> None: """Stop the listener corresponding to the query_id depending on the detected backend protocol. diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py index 4313188e..08fb1bf5 100644 --- a/gql/utilities/node_tree.py +++ b/gql/utilities/node_tree.py @@ -8,7 +8,7 @@ def _node_tree_recursive( *, indent: int = 0, ignored_keys: List, -): +) -> str: assert ignored_keys is not None @@ -65,7 +65,7 @@ def node_tree( ignore_loc: bool = True, ignore_block: bool = True, ignored_keys: Optional[List] = None, -): +) -> str: """Method which returns a tree of Node elements as a String. Useful to debug deep DocumentNode instances created by gql or dsl_gql. diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 02355425..f9bc2e0c 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -44,7 +44,7 @@ } -def _ignore_non_null(type_: GraphQLType): +def _ignore_non_null(type_: GraphQLType) -> GraphQLType: """Removes the GraphQLNonNull wrappings around types.""" if isinstance(type_, GraphQLNonNull): return type_.of_type @@ -153,6 +153,8 @@ def get_current_result_type(self, path): list_level = self.inside_list_level + assert field_type is not None + result_type = _ignore_non_null(field_type) if self.in_first_field(path): diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py index 80c73862..6f7ba0ce 100644 --- a/gql/utilities/update_schema_enum.py +++ b/gql/utilities/update_schema_enum.py @@ -9,7 +9,7 @@ def update_schema_enum( name: str, values: Union[Dict[str, Any], Type[Enum]], use_enum_values: bool = False, -): +) -> None: """Update in the schema the GraphQLEnumType corresponding to the given name. Example:: diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index db3adb17..c2c1b4e8 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -3,7 +3,9 @@ from graphql import GraphQLScalarType, GraphQLSchema -def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): +def update_schema_scalar( + schema: GraphQLSchema, name: str, scalar: GraphQLScalarType +) -> None: """Update the scalar in a schema with the scalar provided. :param schema: the GraphQL schema @@ -36,7 +38,9 @@ def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalar setattr(schema_scalar, "parse_literal", scalar.parse_literal) -def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): +def update_schema_scalars( + schema: GraphQLSchema, scalars: List[GraphQLScalarType] +) -> None: """Update the scalars in a schema with the scalars provided. :param schema: the GraphQL schema diff --git a/gql/utils.py b/gql/utils.py index b4265ce1..6a7d0791 100644 --- a/gql/utils.py +++ b/gql/utils.py @@ -25,17 +25,17 @@ def recurse_extract(path, obj): """ nonlocal files if isinstance(obj, list): - nulled_obj = [] + nulled_list = [] for key, value in enumerate(obj): value = recurse_extract(f"{path}.{key}", value) - nulled_obj.append(value) - return nulled_obj + nulled_list.append(value) + return nulled_list elif isinstance(obj, dict): - nulled_obj = {} + nulled_dict = {} for key, value in obj.items(): value = recurse_extract(f"{path}.{key}", value) - nulled_obj[key] = value - return nulled_obj + nulled_dict[key] = value + return nulled_dict elif isinstance(obj, file_classes): # extract obj from its parent and put it into files instead. files[path] = obj diff --git a/pyproject.toml b/pyproject.toml index 122cec88..f5eb5c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,5 +8,15 @@ dynamic = ["authors", "classifiers", "dependencies", "description", "entry-point requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.isort] +extra_standard_library = "ssl" +known_first_party = "gql" +profile = "black" + [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "function" + +[tool.mypy] +ignore_missing_imports = true +check_untyped_defs = true +disallow_incomplete_defs = true diff --git a/setup.cfg b/setup.cfg index 66380493..533b80f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,16 +1,5 @@ [flake8] max-line-length = 88 -[isort] -known_standard_library = ssl -known_first_party = gql -multi_line_output = 3 -include_trailing_comma = True -line_length = 88 -not_skip = __init__.py - -[mypy] -ignore_missing_imports = true - [tool:pytest] norecursedirs = venv .venv .tox .git .cache .mypy_cache .pytest_cache diff --git a/setup.py b/setup.py index e8be1ef6..f000136c 100644 --- a/setup.py +++ b/setup.py @@ -24,15 +24,15 @@ ] dev_requires = [ - "black==22.3.0", + "black==25.1.0", "check-manifest>=0.42,<1", - "flake8==7.1.1", - "isort==4.3.21", - "mypy==1.10", + "flake8==7.1.2", + "isort==6.0.1", + "mypy==1.15", "sphinx>=7.0.0,<8;python_version<='3.9'", "sphinx>=8.1.0,<9;python_version>'3.9'", "sphinx_rtd_theme>=3.0.2,<4", - "sphinx-argparse==0.4.0", + "sphinx-argparse==0.5.2", "types-aiofiles", "types-requests", ] + tests_requires diff --git a/tests/conftest.py b/tests/conftest.py index 5b8807ae..70a050d5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,10 +10,11 @@ import tempfile import types from concurrent.futures import ThreadPoolExecutor -from typing import Union +from typing import Callable, Iterable, List, Union, cast import pytest import pytest_asyncio +from _pytest.fixtures import SubRequest from gql import Client @@ -219,7 +220,7 @@ async def start(self, handler, extra_serve_args=None): self.server = await self.start_server # Get hostname and port - hostname, port = self.server.sockets[0].getsockname()[:2] + hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore assert hostname == "127.0.0.1" self.hostname = hostname @@ -250,7 +251,7 @@ def __init__(self, with_ssl=False): if with_ssl: _, self.ssl_context = get_localhost_ssl_context() - def get_default_server_handler(answers): + def get_default_server_handler(answers: Iterable[str]) -> Callable: async def default_server_handler(request): import aiohttp @@ -291,7 +292,7 @@ async def default_server_handler(request): elif msg.type == WSMsgType.ERROR: print(f"WebSocket connection closed with: {ws.exception()}") - raise ws.exception() + raise ws.exception() # type: ignore elif msg.type in ( WSMsgType.CLOSE, WSMsgType.CLOSED, @@ -341,7 +342,8 @@ async def start(self, handler): await self.site.start() # Retrieve the actual port the server is listening on - sockets = self.site._server.sockets + assert self.site._server is not None + sockets = self.site._server.sockets # type: ignore if sockets: self.port = sockets[0].getsockname()[1] protocol = "https" if self.with_ssl else "http" @@ -448,7 +450,7 @@ async def send_connection_ack(ws): class TemporaryFile: """Class used to generate temporary files for the tests""" - def __init__(self, content: Union[str, bytearray]): + def __init__(self, content: Union[str, bytearray, bytes]): mode = "w" if isinstance(content, str) else "wb" @@ -474,24 +476,30 @@ def __exit__(self, type, value, traceback): os.unlink(self.filename) -def get_aiohttp_ws_server_handler(request): +def get_aiohttp_ws_server_handler( + request: SubRequest, +) -> Callable: """Get the server handler for the aiohttp websocket server. Either get it from test or use the default server handler if the test provides only an array of answers. """ + server_handler: Callable + if isinstance(request.param, types.FunctionType): server_handler = request.param else: - answers = request.param + answers = cast(List[str], request.param) server_handler = AIOHTTPWebsocketServer.get_default_server_handler(answers) return server_handler -def get_server_handler(request): +def get_server_handler( + request: SubRequest, +) -> Callable: """Get the server handler. Either get it from test or use the default server handler @@ -501,7 +509,7 @@ def get_server_handler(request): from websockets.exceptions import ConnectionClosed if isinstance(request.param, types.FunctionType): - server_handler = request.param + server_handler: Callable = request.param else: answers = request.param diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 2f15a8ca..3526d548 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional import pytest from graphql import ( @@ -6,6 +7,7 @@ GraphQLEnumType, GraphQLField, GraphQLList, + GraphQLNamedType, GraphQLNonNull, GraphQLObjectType, GraphQLSchema, @@ -251,19 +253,30 @@ def test_list_of_list_of_list(): def test_update_schema_enum(): - assert schema.get_type("Color").parse_value("RED") == Color.RED + color_type: Optional[GraphQLNamedType] + + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == Color.RED # Using values update_schema_enum(schema, "Color", Color, use_enum_values=True) - assert schema.get_type("Color").parse_value("RED") == 0 - assert schema.get_type("Color").serialize(1) == "GREEN" + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == 0 + assert color_type.serialize(1) == "GREEN" update_schema_enum(schema, "Color", Color) - assert schema.get_type("Color").parse_value("RED") == Color.RED - assert schema.get_type("Color").serialize(Color.RED) == "RED" + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == Color.RED + assert color_type.serialize(Color.RED) == "RED" def test_update_schema_enum_errors(): @@ -273,20 +286,22 @@ def test_update_schema_enum_errors(): assert "Enum Corlo not found in schema!" in str(exc_info) - with pytest.raises(TypeError) as exc_info: - update_schema_enum(schema, "Color", 6) + with pytest.raises(TypeError) as exc_info2: + update_schema_enum(schema, "Color", 6) # type: ignore - assert "Invalid type for enum values: " in str(exc_info) + assert "Invalid type for enum values: " in str(exc_info2) - with pytest.raises(TypeError) as exc_info: + with pytest.raises(TypeError) as exc_info3: update_schema_enum(schema, "RootQueryType", Color) - assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str(exc_info) + assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str( + exc_info3 + ) - with pytest.raises(KeyError) as exc_info: + with pytest.raises(KeyError) as exc_info4: update_schema_enum(schema, "Color", {"RED": Color.RED}) - assert 'Enum key "GREEN" not found in provided values!' in str(exc_info) + assert 'Enum key "GREEN" not found in provided values!' in str(exc_info4) def test_parse_results_with_operation_type(): diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index cf4ca45d..39f1a1cb 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -441,9 +441,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: [ { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } for result in results ] @@ -453,9 +453,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: return web.json_response( { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } ) @@ -680,14 +680,14 @@ async def test_update_schema_scalars(aiohttp_server): def test_update_schema_scalars_invalid_scalar(): with pytest.raises(TypeError) as exc_info: - update_schema_scalars(schema, [int]) + update_schema_scalars(schema, [int]) # type: ignore exception = exc_info.value assert str(exception) == "Scalars should be instances of GraphQLScalarType." with pytest.raises(TypeError) as exc_info: - update_schema_scalar(schema, "test", int) + update_schema_scalar(schema, "test", int) # type: ignore exception = exc_info.value @@ -697,7 +697,7 @@ def test_update_schema_scalars_invalid_scalar(): def test_update_schema_scalars_invalid_scalar_argument(): with pytest.raises(TypeError) as exc_info: - update_schema_scalars(schema, MoneyScalar) + update_schema_scalars(schema, MoneyScalar) # type: ignore exception = exc_info.value @@ -787,7 +787,7 @@ def test_code(): def test_serialize_value_with_invalid_type(): with pytest.raises(GraphQLError) as exc_info: - serialize_value("Not a valid type", 50) + serialize_value("Not a valid type", 50) # type: ignore exception = exc_info.value diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index c0177a32..61e80fa0 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -12,10 +12,10 @@ def _fake_signer_factory(request=None): class FakeSigner: - def __init__(self, request=None) -> None: + def __init__(self, request=None): self.request = request - def add_auth(self, request) -> None: + def add_auth(self, request): """ A fake for getting a request object that :return: diff --git a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py index b31ade7f..e4653d48 100644 --- a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py +++ b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py @@ -1,3 +1,5 @@ +from graphql import GraphQLSchema + from gql import Client, gql from gql.dsl import DSLFragment, DSLQuery, DSLSchema, dsl_gql, print_ast from gql.utilities import node_tree @@ -34,6 +36,9 @@ def test_issue_447(): client = Client(schema=schema_str) + + assert isinstance(client.schema, GraphQLSchema) + ds = DSLSchema(client.schema) sprite = DSLFragment("SpriteUnionAsSprite") diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 59d7ddfa..1d179f60 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -148,9 +148,10 @@ def create_review(episode, review): async def make_starwars_backend(aiohttp_server): from aiohttp import web - from .schema import StarWarsSchema from graphql import graphql_sync + from .schema import StarWarsSchema + async def handler(request): data = await request.json() source = data["query"] diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 4b672ad3..8f1efe99 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -1,4 +1,5 @@ import asyncio +from typing import cast from graphql import ( GraphQLArgument, @@ -14,6 +15,7 @@ GraphQLObjectType, GraphQLSchema, GraphQLString, + IntrospectionQuery, get_introspection_query, graphql_sync, print_schema, @@ -271,6 +273,8 @@ async def resolve_review(review, _info, **_args): ) -StarWarsIntrospection = graphql_sync(StarWarsSchema, get_introspection_query()).data +StarWarsIntrospection = cast( + IntrospectionQuery, graphql_sync(StarWarsSchema, get_introspection_query()).data +) StarWarsTypeDef = print_schema(StarWarsSchema) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 5cd051ba..d96435fc 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -4,6 +4,7 @@ GraphQLError, GraphQLFloat, GraphQLID, + GraphQLInputObjectType, GraphQLInt, GraphQLList, GraphQLNonNull, @@ -53,6 +54,7 @@ def client(): def test_ast_from_value_with_input_type_and_not_mapping_value(): obj_type = StarWarsSchema.get_type("ReviewInput") + assert isinstance(obj_type, GraphQLInputObjectType) assert ast_from_value(8, obj_type) is None @@ -78,7 +80,7 @@ def test_ast_from_value_with_graphqlid(): def test_ast_from_value_with_invalid_type(): with pytest.raises(TypeError) as exc_info: - ast_from_value(4, None) + ast_from_value(4, None) # type: ignore assert "Unexpected input type: None." in str(exc_info.value) @@ -114,7 +116,10 @@ def test_ast_from_serialized_value_untyped_typeerror(): def test_variable_to_ast_type_passing_wrapping_type(): - wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("ReviewInput"))) + review_type = StarWarsSchema.get_type("ReviewInput") + assert isinstance(review_type, GraphQLInputObjectType) + + wrapping_type = GraphQLNonNull(GraphQLList(review_type)) variable = DSLVariable("review_input") ast = variable.to_ast_type(wrapping_type) assert ast == NonNullTypeNode( @@ -383,7 +388,7 @@ def test_fetch_luke_aliased(ds): assert query == str(query_dsl) -def test_fetch_name_aliased(ds: DSLSchema): +def test_fetch_name_aliased(ds: DSLSchema) -> None: query = """ human(id: "1000") { my_name: name @@ -394,7 +399,7 @@ def test_fetch_name_aliased(ds: DSLSchema): assert query == str(query_dsl) -def test_fetch_name_aliased_as_kwargs(ds: DSLSchema): +def test_fetch_name_aliased_as_kwargs(ds: DSLSchema) -> None: query = """ human(id: "1000") { my_name: name @@ -787,7 +792,7 @@ def test_dsl_query_all_fields_should_be_instances_of_DSLField(): TypeError, match="Fields should be instances of DSLSelectable. Received: ", ): - DSLQuery("I am a string") + DSLQuery("I am a string") # type: ignore def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): @@ -839,7 +844,7 @@ def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( TypeError, match="Operations should be instances of DSLExecutable " ): - dsl_gql("I am a string") + dsl_gql("I am a string") # type: ignore def test_DSLSchema_requires_a_schema(client): diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py index e8f3f8d4..8020b586 100644 --- a/tests/starwars/test_parse_results.py +++ b/tests/starwars/test_parse_results.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import pytest from graphql import GraphQLError @@ -87,7 +89,7 @@ def test_key_not_found_in_result(): # Backend returned an invalid result without the hero key # Should be impossible. In that case, we ignore the missing key - result = {} + result: Dict[str, Any] = {} parsed_result = parse_result(StarWarsSchema, query, result) diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index bf15e11a..7a2a8084 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -336,4 +336,4 @@ def test_query_from_source(client): def test_already_parsed_query(client): query = gql("{ hero { name } }") with pytest.raises(TypeError, match="must be passed as a string"): - gql(query) + gql(query) # type: ignore diff --git a/tests/starwars/test_validation.py b/tests/starwars/test_validation.py index 38676836..75ce4162 100644 --- a/tests/starwars/test_validation.py +++ b/tests/starwars/test_validation.py @@ -79,7 +79,7 @@ def introspection_schema_no_directives(): introspection = copy.deepcopy(StarWarsIntrospection) # Simulate no directives key - del introspection["__schema"]["directives"] + del introspection["__schema"]["directives"] # type: ignore return Client(introspection=introspection) @@ -108,7 +108,7 @@ def validation_errors(client, query): def test_incompatible_request_gql(client): with pytest.raises(TypeError): - gql(123) + gql(123) # type: ignore """ The error generated depends on graphql-core version @@ -253,7 +253,7 @@ def test_build_client_schema_invalid_introspection(): from gql.utilities import build_client_schema with pytest.raises(TypeError) as exc_info: - build_client_schema("blah") + build_client_schema("blah") # type: ignore assert ( "Invalid or incomplete introspection result. Ensure that you are passing the " diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index e843db6c..04417c4e 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -47,6 +47,7 @@ @pytest.mark.asyncio async def test_aiohttp_query(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -86,6 +87,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_ignore_backend_content_type(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -115,6 +117,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cookies(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -148,6 +151,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_401(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -179,6 +183,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_429(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -226,6 +231,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_500(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -261,6 +267,7 @@ async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_aiohttp_error_code(aiohttp_server, query_error): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -316,6 +323,7 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_aiohttp_invalid_protocol(aiohttp_server, param): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport response = param["response"] @@ -344,6 +352,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_subscribe_not_supported(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -369,6 +378,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_connect_twice(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -391,6 +401,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -413,6 +424,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_extra_args(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -460,6 +472,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_variable_values(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -496,6 +509,7 @@ async def test_aiohttp_query_variable_values_fix_issue_292(aiohttp_server): See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -526,6 +540,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -552,6 +567,7 @@ def test_code(): @pytest.mark.asyncio async def test_aiohttp_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -638,6 +654,7 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -703,6 +720,7 @@ async def single_upload_handler_with_content_type(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_with_content_type(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -724,7 +742,7 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore params = {"file": f, "other_var": 42} @@ -741,6 +759,7 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): @pytest.mark.asyncio async def test_aiohttp_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -809,6 +828,7 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_binary_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -843,7 +863,8 @@ async def test_aiohttp_binary_file_upload(aiohttp_server): @pytest.mark.asyncio async def test_aiohttp_stream_reader_upload(aiohttp_server): - from aiohttp import web, ClientSession + from aiohttp import ClientSession, web + from gql.transport.aiohttp import AIOHTTPTransport async def binary_data_handler(request): @@ -882,6 +903,7 @@ async def binary_data_handler(request): async def test_aiohttp_async_generator_upload(aiohttp_server): import aiofiles from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -944,6 +966,7 @@ async def file_sender(file_name): @pytest.mark.asyncio async def test_aiohttp_file_upload_two_files(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1035,6 +1058,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1253,6 +1277,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_with_extensions(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1282,6 +1307,7 @@ async def handler(request): @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_aiohttp_query_https(ssl_aiohttp_server, ssl_close_timeout, verify_https): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1328,8 +1354,9 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_https_self_cert_fail(ssl_aiohttp_server): """By default, we should verify the ssl certificate""" - from aiohttp.client_exceptions import ClientConnectorCertificateError from aiohttp import web + from aiohttp.client_exceptions import ClientConnectorCertificateError + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1361,6 +1388,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_https_self_cert_default(ssl_aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1382,6 +1410,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_fetching_schema(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport error_answer = """ @@ -1425,6 +1454,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_reconnecting_session(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1463,6 +1493,7 @@ async def handler(request): @pytest.mark.parametrize("retries", [False, lambda e: e]) async def test_aiohttp_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1496,6 +1527,7 @@ async def test_aiohttp_reconnecting_session_start_connecting_task_twice( aiohttp_server, caplog ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1529,6 +1561,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_json_serializer(aiohttp_server, caplog): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1584,9 +1617,11 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_json_deserializer(aiohttp_server): - from aiohttp import web from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1623,7 +1658,8 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_connector_owner_false(aiohttp_server): - from aiohttp import web, TCPConnector + from aiohttp import TCPConnector, web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 801af6b9..86c502a9 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -179,7 +179,7 @@ async def monkey_patch_send_query( document, variable_values=None, operation_name=None, - ) -> int: + ): query_id = self.next_query_id self.next_query_id += 1 diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 8863ead9..e8832217 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,6 +8,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -763,6 +764,7 @@ def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -818,8 +820,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( graphqlws_server, subscription_str, execute_instead_of_subscribe ): - from gql.transport.exceptions import TransportClosed from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + from gql.transport.exceptions import TransportClosed path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index deb425f7..cf91d148 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -1,7 +1,7 @@ import asyncio import json import sys -from typing import Dict, Mapping +from typing import Any, Dict, Mapping import pytest @@ -66,7 +66,8 @@ async def test_aiohttp_websocket_starting_client_in_context_manager(aiohttp_ws_s ) assert transport.response_headers == {} - assert transport.headers["test"] == "1234" + assert isinstance(transport.headers, Mapping) + assert transport.headers["test"] == "1234" # type: ignore async with Client(transport=transport) as session: @@ -154,6 +155,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( ): from aiohttp.client_exceptions import ClientConnectorCertificateError + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport server = ws_ssl_server @@ -161,7 +163,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["ssl"] = True @@ -645,7 +647,6 @@ async def test_aiohttp_websocket_non_regression_bug_108( async def test_aiohttp_websocket_using_cli( aiohttp_ws_server, transport_arg, monkeypatch, capsys ): - """ Note: depending on the transport_arg parameter, if there is no transport argument, then we will use WebsocketsTransport if the websockets dependency is installed, diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 5beb023e..83ae3589 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,6 +9,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -228,6 +229,7 @@ async def test_aiohttp_websocket_subscription_get_execution_result( async for result in session.subscribe(subscription, get_execution_result=True): assert isinstance(result, ExecutionResult) + assert result.data is not None number = result.data["number"] print(f"Number received: {number}") @@ -669,6 +671,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) interrupt_task = asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -678,6 +681,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( assert count == 4 # Catch interrupt_task exception to remove warning + assert interrupt_task is not None interrupt_task.exception() # Check that the server received a connection_terminate message last diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index cb279ae5..8abb3410 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -23,6 +23,7 @@ def test_appsync_init_with_minimal_args(fake_session_factory): @pytest.mark.botocore def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): @@ -72,6 +73,7 @@ def test_appsync_init_with_apikey_auth(): @pytest.mark.botocore def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -108,10 +110,13 @@ def test_appsync_init_with_iam_auth_and_no_region( - you have the AWS_DEFAULT_REGION environment variable set """ - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.exceptions import NoRegionError import logging + from botocore.exceptions import NoRegionError + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + caplog.set_level(logging.WARNING) with pytest.raises(NoRegionError): @@ -120,6 +125,8 @@ def test_appsync_init_with_iam_auth_and_no_region( session._credentials.region = None transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=session) + assert isinstance(transport.auth, AppSyncIAMAuthentication) + # prints the region name in case the test fails print(f"Region found: {transport.auth._region_name}") diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index 2a6c9ca7..536b2fe9 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -9,10 +9,12 @@ @pytest.mark.aiohttp @pytest.mark.botocore async def test_appsync_iam_mutation(aiohttp_server, fake_credentials_factory): + from urllib.parse import urlparse + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncIAMAuthentication - from urllib.parse import urlparse async def handler(request): data = { diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 7aa96292..37cbe460 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -426,9 +426,10 @@ async def test_appsync_subscription_api_key(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_with_token(server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -453,9 +454,10 @@ async def test_appsync_subscription_iam_with_token(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_without_token(server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -479,9 +481,10 @@ async def test_appsync_subscription_iam_without_token(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_execute_method_not_allowed(server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -526,9 +529,10 @@ async def test_appsync_execute_method_not_allowed(server): @pytest.mark.botocore async def test_appsync_fetch_schema_from_transport_not_allowed(): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials dummy_credentials = Credentials( access_key=DUMMY_ACCESS_KEY_ID, @@ -579,10 +583,11 @@ async def test_appsync_subscription_api_key_unauthorized(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_not_allowed(server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportQueryError - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" diff --git a/tests/test_cli.py b/tests/test_cli.py index dccfcb5a..4c6b7d15 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -286,8 +286,8 @@ async def test_cli_main_appsync_websockets_iam(parser, url): ) def test_cli_get_transport_appsync_websockets_api_key(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--api-key", "test-api-key"] @@ -307,8 +307,8 @@ def test_cli_get_transport_appsync_websockets_api_key(parser, url): ) def test_cli_get_transport_appsync_websockets_jwt(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncJWTAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--jwt", "test-jwt"] diff --git a/tests/test_client.py b/tests/test_client.py index e5edec8b..8669b4a3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,10 @@ import os from contextlib import suppress +from typing import Any from unittest import mock import pytest -from graphql import build_ast_schema, parse +from graphql import DocumentNode, ExecutionResult, build_ast_schema, parse from gql import Client, GraphQLRequest, gql from gql.transport import Transport @@ -29,19 +30,27 @@ def http_transport_query(): def test_request_transport_not_implemented(http_transport_query): class RandomTransport(Transport): - def execute(self): - super().execute(http_transport_query) + pass - with pytest.raises(NotImplementedError) as exc_info: - RandomTransport().execute() + with pytest.raises(TypeError) as exc_info: + RandomTransport() # type: ignore - assert "Any Transport subclass must implement execute method" == str(exc_info.value) + assert "Can't instantiate abstract class RandomTransport" in str(exc_info.value) - with pytest.raises(NotImplementedError) as exc_info: - RandomTransport().execute_batch([]) + class RandomTransport2(Transport): + def execute( + self, + document: DocumentNode, + *args: Any, + **kwargs: Any, + ) -> ExecutionResult: + return ExecutionResult() + + with pytest.raises(NotImplementedError) as exc_info2: + RandomTransport2().execute_batch([]) assert "This Transport has not implemented the execute_batch method" == str( - exc_info.value + exc_info2.value ) @@ -70,7 +79,7 @@ def test_retries_on_transport(execute_mock): expected_retries = 3 execute_mock.side_effect = NewConnectionError( - "Should be HTTPConnection", "Fake connection error" + "Should be HTTPConnection", "Fake connection error" # type: ignore ) transport = RequestsHTTPTransport( url="http://127.0.0.1:8000/graphql", @@ -109,11 +118,10 @@ def test_retries_on_transport(execute_mock): assert execute_mock.call_count == expected_retries + 1 -def test_no_schema_exception(): +def test_no_schema_no_transport_exception(): with pytest.raises(AssertionError) as exc_info: - client = Client() - client.validate("") - assert "Cannot validate the document locally, you need to pass a schema." in str( + Client() + assert "You need to provide either a transport or a schema to the Client." in str( exc_info.value ) @@ -255,6 +263,7 @@ def test_sync_transport_close_on_schema_retrieval_failure(): # transport is closed afterwards pass + assert isinstance(client.transport, RequestsHTTPTransport) assert client.transport.session is None @@ -279,6 +288,7 @@ async def test_async_transport_close_on_schema_retrieval_failure(): # transport is closed afterwards pass + assert isinstance(client.transport, AIOHTTPTransport) assert client.transport.session is None import asyncio diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 2735fbb0..94028d26 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,6 +8,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -757,6 +758,7 @@ def test_graphqlws_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -812,8 +814,8 @@ async def test_graphqlws_subscription_reconnecting_session( graphqlws_server, subscription_str, execute_instead_of_subscribe ): - from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import TransportClosed + from gql.transport.websockets import WebsocketsTransport path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" diff --git a/tests/test_httpx.py b/tests/test_httpx.py index c15872d7..43d74ec6 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -1,4 +1,4 @@ -from typing import Mapping +from typing import Any, Dict, Mapping import pytest @@ -38,6 +38,7 @@ @pytest.mark.asyncio async def test_httpx_query(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -82,6 +83,7 @@ def test_code(): @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_httpx_query_https(ssl_aiohttp_server, run_sync_test, verify_https): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -144,6 +146,7 @@ async def test_httpx_query_https_self_cert_fail( """By default, we should verify the ssl certificate""" from aiohttp import web from httpx import ConnectError + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -162,7 +165,7 @@ async def handler(request): assert str(url).startswith("https://") def test_code(): - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -191,6 +194,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -228,6 +232,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -263,6 +268,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -312,6 +318,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -344,6 +351,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -382,6 +390,7 @@ def test_code(): @pytest.mark.parametrize("response", invalid_protocol_responses) async def test_httpx_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -410,6 +419,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -436,6 +446,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -473,6 +484,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -528,6 +540,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -588,6 +601,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_with_content_type(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -638,7 +652,7 @@ def test_code(): with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore params = {"file": f, "other_var": 42} execution_result = session._execute( @@ -654,6 +668,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -716,6 +731,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_binary_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport # This is a sample binary file content containing all possible byte values @@ -789,6 +805,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_2 = """ @@ -887,6 +904,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_3 = """ @@ -976,6 +994,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport error_answer = """ diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 44764ea4..49ea6a24 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1,6 +1,6 @@ import io import json -from typing import Mapping +from typing import Any, Dict, Mapping import pytest @@ -48,6 +48,7 @@ @pytest.mark.asyncio async def test_httpx_query(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -88,6 +89,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_ignore_backend_content_type(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -118,6 +120,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cookies(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -152,6 +155,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_401(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -184,6 +188,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_429(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -232,6 +237,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_500(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -268,6 +274,7 @@ async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_httpx_error_code(aiohttp_server, query_error): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -324,6 +331,7 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_httpx_invalid_protocol(aiohttp_server, param): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport response = param["response"] @@ -353,6 +361,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_subscribe_not_supported(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -379,6 +388,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -402,6 +412,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -424,9 +435,10 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_extra_args(aiohttp_server): + import httpx from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport - import httpx async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -438,8 +450,8 @@ async def handler(request): url = str(server.make_url("/")) # passing extra arguments to httpx.AsyncClient - transport = httpx.AsyncHTTPTransport(retries=2) - transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=transport) + inner_transport = httpx.AsyncHTTPTransport(retries=2) + transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=inner_transport) async with Client(transport=transport) as session: @@ -470,6 +482,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_variable_values(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -507,6 +520,7 @@ async def test_httpx_query_variable_values_fix_issue_292(aiohttp_server): See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -538,6 +552,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -565,6 +580,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -652,6 +668,7 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_httpx_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -688,6 +705,7 @@ async def test_httpx_file_upload(aiohttp_server): @pytest.mark.asyncio async def test_httpx_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -757,6 +775,7 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_httpx_binary_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -815,6 +834,7 @@ async def test_httpx_binary_file_upload(aiohttp_server): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -907,6 +927,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1130,6 +1151,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_with_extensions(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1159,6 +1181,7 @@ async def handler(request): @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_httpx_query_https(ssl_aiohttp_server, verify_https): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1202,9 +1225,10 @@ async def handler(request): @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_httpx_query_https_self_cert_fail(ssl_aiohttp_server, verify_https): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport from httpx import ConnectError + from gql.transport.httpx import HTTPXAsyncTransport + async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -1216,7 +1240,7 @@ async def handler(request): assert url.startswith("https://") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -1240,6 +1264,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport error_answer = """ @@ -1284,6 +1309,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_reconnecting_session(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1323,6 +1349,7 @@ async def handler(request): @pytest.mark.parametrize("retries", [False, lambda e: e]) async def test_httpx_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1357,6 +1384,7 @@ async def test_httpx_reconnecting_session_start_connecting_task_twice( aiohttp_server, caplog ): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1391,6 +1419,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_json_serializer(aiohttp_server, caplog): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1447,9 +1476,11 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_json_deserializer(aiohttp_server): - from aiohttp import web from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 2a312d71..09c129b3 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -19,9 +19,7 @@ def ensure_list(s): return ( s if s is None or isinstance(s, list) - else list(s) - if isinstance(s, tuple) - else [s] + else list(s) if isinstance(s, tuple) else [s] ) @@ -360,9 +358,10 @@ def subscription_server( data_answers=default_subscription_data_answer, unsubscribe_answers=default_subscription_unsubscribe_answer, ): - from .conftest import PhoenixChannelServerHelper import json + from .conftest import PhoenixChannelServerHelper + async def phoenix_server(ws): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 621f648e..7dff7062 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -110,10 +110,11 @@ async def test_phoenix_channel_query_ssl(ws_ssl_server, query_str): async def test_phoenix_channel_query_ssl_self_cert_fail( ws_ssl_server, query_str, verify_https ): + from ssl import SSLCertVerificationError + from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) - from ssl import SSLCertVerificationError path = "/graphql" server = ws_ssl_server diff --git a/tests/test_requests.py b/tests/test_requests.py index 8f3b0b7a..9c0334bd 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,4 +1,4 @@ -from typing import Mapping +from typing import Any, Dict, Mapping import pytest @@ -42,6 +42,7 @@ @pytest.mark.asyncio async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -85,9 +86,11 @@ def test_code(): @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_requests_query_https(ssl_aiohttp_server, run_sync_test, verify_https): + import warnings + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - import warnings async def handler(request): return web.Response( @@ -151,9 +154,10 @@ async def test_requests_query_https_self_cert_fail( ): """By default, we should verify the ssl certificate""" from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport from requests.exceptions import SSLError + from gql.transport.requests import RequestsHTTPTransport + async def handler(request): return web.Response( text=query1_server_answer, @@ -168,7 +172,7 @@ async def handler(request): url = server.make_url("/") def test_code(): - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -197,6 +201,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -234,6 +239,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -269,6 +275,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -318,6 +325,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -350,6 +358,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -388,6 +397,7 @@ def test_code(): @pytest.mark.parametrize("response", invalid_protocol_responses) async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -416,6 +426,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_connect_twice(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -442,6 +453,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -479,6 +491,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -534,6 +547,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -594,6 +608,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_with_content_type(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -644,7 +659,7 @@ def test_code(): with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore params = {"file": f, "other_var": 42} execution_result = session._execute( @@ -660,6 +675,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -722,6 +738,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_binary_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport # This is a sample binary file content containing all possible byte values @@ -795,6 +812,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_2 = """ @@ -893,6 +911,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_3 = """ @@ -982,6 +1001,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport error_answer = """ @@ -1029,7 +1049,9 @@ def test_code(): @pytest.mark.asyncio async def test_requests_json_serializer(aiohttp_server, run_sync_test, caplog): import json + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -1089,9 +1111,11 @@ def test_code(): @pytest.mark.asyncio async def test_requests_json_deserializer(aiohttp_server, run_sync_test): import json - from aiohttp import web from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index dbd3dfa5..4b9e09b8 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -50,6 +50,7 @@ @pytest.mark.asyncio async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -93,6 +94,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_query_auto_batch_enabled(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -140,9 +142,11 @@ def test_code(): async def test_requests_query_auto_batch_enabled_two_requests( aiohttp_server, run_sync_test ): + from threading import Thread + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - from threading import Thread async def handler(request): return web.Response( @@ -199,6 +203,7 @@ def test_thread(): @pytest.mark.asyncio async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -238,6 +243,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -275,6 +281,7 @@ async def test_requests_error_code_401_auto_batch_enabled( aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -313,6 +320,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -362,6 +370,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -394,6 +403,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -437,6 +447,7 @@ def test_code(): @pytest.mark.parametrize("response", invalid_protocol_responses) async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -465,6 +476,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -504,6 +516,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -543,6 +556,7 @@ def test_code(): def test_requests_sync_batch_auto(): from threading import Thread + from gql.transport.requests import RequestsHTTPTransport client = Client( diff --git a/tests/test_transport.py b/tests/test_transport.py index d9a3eced..e554955a 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -28,6 +28,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index a9b21e6a..7c108ec3 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -28,6 +28,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 9c43965f..08058aea 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -175,7 +175,7 @@ async def monkey_patch_send_query( document, variable_values=None, operation_name=None, - ) -> int: + ): query_id = self.next_query_id self.next_query_id += 1 @@ -366,9 +366,10 @@ async def test_websocket_using_cli_invalid_query(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from gql.cli import main, get_parser import io + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_online.py b/tests/test_websocket_online.py index fa288b6d..c53be5f4 100644 --- a/tests/test_websocket_online.py +++ b/tests/test_websocket_online.py @@ -27,12 +27,10 @@ async def test_websocket_simple_query(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( - url="wss://countries.trevorblades.com/graphql" - ) + transport = WebsocketsTransport(url="wss://countries.trevorblades.com/graphql") # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -68,12 +66,12 @@ async def test_websocket_invalid_query(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -98,12 +96,12 @@ async def test_websocket_sending_invalid_data(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -122,7 +120,8 @@ async def test_websocket_sending_invalid_data(): invalid_data = "QSDF" print(f">>> {invalid_data}") - await sample_transport.websocket.send(invalid_data) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_data) await asyncio.sleep(2) @@ -134,17 +133,18 @@ async def test_websocket_sending_invalid_payload(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport): + async with Client(transport=transport): invalid_payload = '{"id": "1", "type": "start", "payload": "BLAHBLAH"}' print(f">>> {invalid_payload}") - await sample_transport.websocket.send(invalid_payload) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_payload) await asyncio.sleep(2) @@ -156,12 +156,12 @@ async def test_websocket_sending_invalid_data_while_other_query_is_running(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -190,7 +190,8 @@ async def query_task2(): invalid_data = "QSDF" print(f">>> {invalid_data}") - await sample_transport.websocket.send(invalid_data) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_data) task1 = asyncio.create_task(query_task1()) task2 = asyncio.create_task(query_task2()) @@ -207,12 +208,12 @@ async def test_websocket_two_queries_in_parallel_using_two_tasks(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 919f6bdb..99ff7334 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,7 +1,7 @@ import asyncio import json import sys -from typing import Dict, Mapping +from typing import Any, Dict, Mapping import pytest @@ -60,6 +60,7 @@ async def test_websocket_starting_client_in_context_manager(server): transport = WebsocketsTransport(url=url, headers={"test": "1234"}) assert transport.response_headers == {} + assert isinstance(transport.headers, Mapping) assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: @@ -93,6 +94,7 @@ async def test_websocket_starting_client_in_context_manager(server): @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) async def test_websocket_using_ssl_connection(ws_ssl_server): import websockets + from gql.transport.websockets import WebsocketsTransport server = ws_ssl_server @@ -138,15 +140,16 @@ async def test_websocket_using_ssl_connection(ws_ssl_server): async def test_websocket_using_ssl_connection_self_cert_fail( ws_ssl_server, verify_https ): - from gql.transport.websockets import WebsocketsTransport from ssl import SSLCertVerificationError + from gql.transport.websockets import WebsocketsTransport + server = ws_ssl_server url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["ssl"] = True @@ -585,10 +588,11 @@ async def test_websocket_using_cli(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from gql.cli import main, get_parser import io import json + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index a020e1f5..89acd635 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,6 +9,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -160,6 +161,7 @@ async def test_websocket_subscription_get_execution_result( assert isinstance(result, ExecutionResult) + assert result.data is not None number = result.data["number"] print(f"Number received: {number}") @@ -600,6 +602,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) interrupt_task = asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -609,6 +612,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) assert count == 4 # Catch interrupt_task exception to remove warning + assert interrupt_task is not None interrupt_task.exception() # Check that the server received a connection_terminate message last diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py index f070f497..f0448c79 100644 --- a/tests/test_websockets_adapter.py +++ b/tests/test_websockets_adapter.py @@ -1,4 +1,5 @@ import json +from typing import Mapping import pytest from graphql import print_ast @@ -73,11 +74,12 @@ async def test_websockets_adapter_edge_cases(server): query = print_ast(gql(query1_str)) print("query=", query) - adapter = WebSocketsAdapter(url, headers={"a": 1}, ssl=False, connect_args={}) + adapter = WebSocketsAdapter(url, headers={"a": "r1"}, ssl=False, connect_args={}) await adapter.connect() - assert adapter.headers["a"] == 1 + assert isinstance(adapter.headers, Mapping) + assert adapter.headers["a"] == "r1" assert adapter.ssl is False assert adapter.connect_args == {} assert adapter.response_headers["dummy"] == "test1234" diff --git a/tox.ini b/tox.ini index 8796357b..f6d4b48e 100644 --- a/tox.ini +++ b/tox.ini @@ -47,7 +47,7 @@ commands = basepython = python deps = -e.[dev] commands = - isort --recursive --check-only --diff gql tests + isort --check-only --diff gql tests [testenv:mypy] basepython = python