From 19fd8107576ad026903d0256377708c7a6272f41 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 6 Mar 2025 18:34:28 +0100 Subject: [PATCH 1/9] Put ListenerQueue in separate file --- gql/transport/aiohttp_websockets.py | 61 ++----------------- gql/transport/websockets_base.py | 55 +---------------- gql/transport/websockets_common/__init__.py | 3 + .../websockets_common/listener_queue.py | 58 ++++++++++++++++++ 4 files changed, 66 insertions(+), 111 deletions(-) create mode 100644 gql/transport/websockets_common/__init__.py create mode 100644 gql/transport/websockets_common/listener_queue.py diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 18699b5e..9b84bd9b 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -22,72 +22,19 @@ from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy -from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.async_transport import AsyncTransport -from gql.transport.exceptions import ( +from .aiohttp import AIOHTTPTransport +from .async_transport import AsyncTransport +from .exceptions import ( TransportAlreadyConnected, TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .websockets_common import ListenerQueue log = logging.getLogger("gql.transport.aiohttp_websockets") -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - class AIOHTTPWebsocketsTransport(AsyncTransport): diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index accca275..f8694c16 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -21,63 +21,10 @@ TransportQueryError, TransportServerError, ) +from .websockets_common import ListenerQueue log = logging.getLogger("gql.transport.websockets") -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - class WebsocketsTransportBase(AsyncTransport): """abstract :ref:`Async Transport ` used to implement diff --git a/gql/transport/websockets_common/__init__.py b/gql/transport/websockets_common/__init__.py new file mode 100644 index 00000000..7661cf87 --- /dev/null +++ b/gql/transport/websockets_common/__init__.py @@ -0,0 +1,3 @@ +from .listener_queue import ListenerQueue, ParsedAnswer + +__all__ = ["ListenerQueue", "ParsedAnswer"] diff --git a/gql/transport/websockets_common/listener_queue.py b/gql/transport/websockets_common/listener_queue.py new file mode 100644 index 00000000..54aa650f --- /dev/null +++ b/gql/transport/websockets_common/listener_queue.py @@ -0,0 +1,58 @@ +import asyncio +from typing import Optional, Tuple + +from graphql import ExecutionResult + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True From 5cb5b9a8878b0bd09e9a55c6f836bf904077f986 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 6 Mar 2025 18:41:04 +0100 Subject: [PATCH 2/9] Moving websockets_base.py into websockets_common folder --- gql/transport/phoenix_channel_websockets.py | 2 +- gql/transport/websockets.py | 2 +- .../{websockets_base.py => websockets_common/base.py} | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) rename gql/transport/{websockets_base.py => websockets_common/base.py} (99%) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 08cde8cc..a7b256eb 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -11,7 +11,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase +from .websockets_common.base import WebsocketsTransportBase log = logging.getLogger(__name__) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 02abb61f..adebf249 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -15,7 +15,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase +from .websockets_common.base import WebsocketsTransportBase log = logging.getLogger(__name__) diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_common/base.py similarity index 99% rename from gql/transport/websockets_base.py rename to gql/transport/websockets_common/base.py index f8694c16..4a07a10d 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_common/base.py @@ -13,15 +13,15 @@ from websockets.exceptions import ConnectionClosed from websockets.typing import Data, Subprotocol -from .async_transport import AsyncTransport -from .exceptions import ( +from ..async_transport import AsyncTransport +from ..exceptions import ( TransportAlreadyConnected, TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) -from .websockets_common import ListenerQueue +from .listener_queue import ListenerQueue log = logging.getLogger("gql.transport.websockets") From c369d2a1b67bb485b38202f18926a25c6b54bc0b Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 7 Mar 2025 23:37:43 +0100 Subject: [PATCH 3/9] Refactor WebSockets Transport with Dependency Injection Architecture This major architectural improvement implements dependency injection patterns across the WebSockets transport layer, creating a more modular, testable, and extensible system: - Created abstract AdapterConnection interface in common/adapters/connection.py - Implemented concrete WebSocketsAdapter to wrap the websockets library - Moved websockets_base.py to common/base.py maintaining better structure which is independant of the websockets library used - Added new TransportConnectionClosed exception for clearer error handling - Reorganized code with proper separation of concerns: * Moved common functionality into dedicated adapters folder * Isolated connection handling from transport business logic * Separated ListenerQueue into its own file for better modularity Potential Breaking changes: * New TransportConnectionClosed Exception replacing ConnectionClosed Exception * websocket attribute removed from transport, now using _connected to check if the transport is connected --- gql/transport/aiohttp_websockets.py | 2 +- gql/transport/appsync_websockets.py | 2 +- gql/transport/common/__init__.py | 10 ++ gql/transport/common/adapters/__init__.py | 3 + gql/transport/common/adapters/connection.py | 54 +++++++ gql/transport/common/adapters/websockets.py | 142 +++++++++++++++++ .../{websockets_common => common}/base.py | 148 ++++++------------ .../listener_queue.py | 0 gql/transport/exceptions.py | 7 + gql/transport/phoenix_channel_websockets.py | 4 +- gql/transport/websockets.py | 38 ++--- gql/transport/websockets_base.py | 93 +++++++++++ gql/transport/websockets_common/__init__.py | 3 - tests/conftest.py | 3 +- tests/test_graphqlws_exceptions.py | 8 +- tests/test_graphqlws_subscription.py | 9 +- tests/test_phoenix_channel_query.py | 4 + tests/test_websocket_exceptions.py | 10 +- tests/test_websocket_query.py | 73 +++++++-- tests/test_websocket_subscription.py | 6 +- tests/test_websockets_adapter.py | 98 ++++++++++++ 21 files changed, 556 insertions(+), 161 deletions(-) create mode 100644 gql/transport/common/__init__.py create mode 100644 gql/transport/common/adapters/__init__.py create mode 100644 gql/transport/common/adapters/connection.py create mode 100644 gql/transport/common/adapters/websockets.py rename gql/transport/{websockets_common => common}/base.py (78%) rename gql/transport/{websockets_common => common}/listener_queue.py (100%) create mode 100644 gql/transport/websockets_base.py delete mode 100644 gql/transport/websockets_common/__init__.py create mode 100644 tests/test_websockets_adapter.py diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 9b84bd9b..f97fbba8 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -24,6 +24,7 @@ from .aiohttp import AIOHTTPTransport from .async_transport import AsyncTransport +from .common import ListenerQueue from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -31,7 +32,6 @@ TransportQueryError, TransportServerError, ) -from .websockets_common import ListenerQueue log = logging.getLogger("gql.transport.aiohttp_websockets") diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 66091747..0d5139c3 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -181,7 +181,7 @@ async def _send_query( return query_id - subscribe = WebsocketsTransportBase.subscribe + subscribe = WebsocketsTransportBase.subscribe # type: ignore[assignment] """Send a subscription query and receive the results using a python async generator. diff --git a/gql/transport/common/__init__.py b/gql/transport/common/__init__.py new file mode 100644 index 00000000..a60ce0b0 --- /dev/null +++ b/gql/transport/common/__init__.py @@ -0,0 +1,10 @@ +from .adapters import AdapterConnection +from .base import SubscriptionTransportBase +from .listener_queue import ListenerQueue, ParsedAnswer + +__all__ = [ + "AdapterConnection", + "ListenerQueue", + "ParsedAnswer", + "SubscriptionTransportBase", +] diff --git a/gql/transport/common/adapters/__init__.py b/gql/transport/common/adapters/__init__.py new file mode 100644 index 00000000..593c46b6 --- /dev/null +++ b/gql/transport/common/adapters/__init__.py @@ -0,0 +1,3 @@ +from .connection import AdapterConnection + +__all__ = ["AdapterConnection"] diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py new file mode 100644 index 00000000..fbe38e3b --- /dev/null +++ b/gql/transport/common/adapters/connection.py @@ -0,0 +1,54 @@ +import abc +from typing import Dict + + +class AdapterConnection(abc.ABC): + """Abstract interface for subscription connections. + + This allows different WebSocket implementations to be used interchangeably. + """ + + @abc.abstractmethod + async def connect(self) -> None: + """Connect to the server.""" + pass # pragma: no cover + + @abc.abstractmethod + async def send(self, message: str) -> None: + """Send message to the server. + + Args: + message: String message to send + + Raises: + TransportConnectionClosed: If connection closed + """ + pass # pragma: no cover + + @abc.abstractmethod + async def receive(self) -> str: + """Receive message from the server. + + Returns: + String message received + + Raises: + TransportConnectionClosed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + pass # pragma: no cover + + @abc.abstractmethod + async def close(self) -> None: + """Close the connection.""" + pass # pragma: no cover + + @property + @abc.abstractmethod + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the connection. + + Returns: + Dictionary of response headers + """ + pass # pragma: no cover diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py new file mode 100644 index 00000000..95fbaf39 --- /dev/null +++ b/gql/transport/common/adapters/websockets.py @@ -0,0 +1,142 @@ +from ssl import SSLContext +from typing import Any, Dict, Optional, Union + +import websockets +from websockets.client import WebSocketClientProtocol +from websockets.datastructures import Headers, HeadersLike +from websockets.exceptions import WebSocketException + +from ...exceptions import TransportConnectionClosed, TransportProtocolError +from .connection import AdapterConnection + + +class WebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the websockets library.""" + + def __init__( + self, + url: str, + *, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param connect_args: Other parameters forwarded to websockets.connect + """ + self.url: str = url + self._headers: Optional[HeadersLike] = headers + self.ssl: Union[SSLContext, bool] = ssl + self.connect_args = connect_args + + self.websocket: Optional[WebSocketClientProtocol] = None + self._response_headers: Optional[Headers] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + } + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + # Connection to the specified url + try: + self.websocket = await websockets.client.connect(self.url, **connect_args) + except WebSocketException as e: + raise TransportConnectionClosed("Connection was closed") from e + + self._response_headers = self.websocket.response_headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionClosed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + try: + await self.websocket.send(message) + except WebSocketException as e: + raise TransportConnectionClosed("Connection was closed") from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. + + Returns: + String message received + + Raises: + TransportConnectionClosed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + # Wait for the next websocket frame. Can raise ConnectionClosed + try: + data = await self.websocket.recv() + except WebSocketException as e: + # When the connection is closed, make sure to clean up resources + self.websocket = None + raise TransportConnectionClosed("Connection was closed") from e + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + return answer + + async def close(self) -> None: + """Close the WebSocket connection.""" + if self.websocket: + websocket = self.websocket + self.websocket = None + await websocket.close() + + @property + def headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._headers: + return dict(self._headers) + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers.raw_items()) + return {} diff --git a/gql/transport/websockets_common/base.py b/gql/transport/common/base.py similarity index 78% rename from gql/transport/websockets_common/base.py rename to gql/transport/common/base.py index 4a07a10d..9ee07dd8 100644 --- a/gql/transport/websockets_common/base.py +++ b/gql/transport/common/base.py @@ -3,79 +3,54 @@ import warnings from abc import abstractmethod from contextlib import suppress -from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast +from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union -import websockets from graphql import DocumentNode, ExecutionResult -from websockets.client import WebSocketClientProtocol -from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Data, Subprotocol from ..async_transport import AsyncTransport from ..exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .adapters import AdapterConnection from .listener_queue import ListenerQueue -log = logging.getLogger("gql.transport.websockets") +log = logging.getLogger("gql.transport.common.base") -class WebsocketsTransportBase(AsyncTransport): +class SubscriptionTransportBase(AsyncTransport): """abstract :ref:`Async Transport ` used to implement - different websockets protocols. - - This transport uses asyncio and the websockets library in order to send requests - on a websocket connection. + different subscription protocols (mainly websockets). """ def __init__( self, - url: str, - headers: Optional[HeadersLike] = None, - ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + *, + adapter: AdapterConnection, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, - ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, - connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. - :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. - :param headers: Dict of HTTP Headers. - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param init_payload: Dict of the payload sent in the connection_init message. + :param adapter: The connection dependency adapter :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. If None is provided this will wait forever. + of the connection. If None is provided this will wait forever. :param close_timeout: Timeout in seconds for the close. If None is provided this will wait forever. - :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. - :param connect_args: Other parameters forwarded to websockets.connect """ - self.url: str = url - self.headers: Optional[HeadersLike] = headers - self.ssl: Union[SSLContext, bool] = ssl - self.init_payload: Dict[str, Any] = init_payload - self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self.adapter: AdapterConnection = adapter - self.connect_args = connect_args - - self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} @@ -105,18 +80,14 @@ def __init__( self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" - self._connecting: bool = False + self._connected: bool = False self.close_exception: Optional[Exception] = None - # The list of supported subprotocols should be defined in the subclass - self.supported_subprotocols: List[Subprotocol] = [] - - self.response_headers: Optional[Headers] = None + @property + def response_headers(self) -> Dict[str, str]: + return self.adapter.response_headers async def _initialize(self): """Hook to send the initialization messages after the connection @@ -153,36 +124,30 @@ async def _connection_terminate(self): pass # pragma: no cover async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message""" + """Send the provided message to the adapter connection and log the message""" - if not self.websocket: + if not self._connected: raise TransportClosed( "Transport is not connected" ) from self.close_exception try: - await self.websocket.send(message) + await self.adapter.send(message) log.info(">>> %s", message) - except ConnectionClosed as e: + except TransportConnectionClosed as e: await self._fail(e, clean_close=False) raise e async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" + """Wait the next message from the connection and log the answer""" - # It is possible that the websocket has been already closed in another task - if self.websocket is None: + # It is possible that the connection has been already closed in another task + if not self._connected: raise TransportClosed("Transport is already closed") - # Wait for the next websocket frame. Can raise ConnectionClosed - data: Data = await self.websocket.recv() - - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise TransportProtocolError("Binary data received in the websocket") - - answer: str = data + # Wait for the next frame. + # Can raise TransportConnectionClosed or TransportProtocolError + answer: str = await self.adapter.receive() log.info("<<< %s", answer) @@ -243,10 +208,10 @@ async def _receive_data_loop(self) -> None: try: while True: - # Wait the next answer from the websocket server + # Wait the next answer from the server try: answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: + except (TransportConnectionClosed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break except TransportClosed: @@ -331,7 +296,7 @@ async def subscribe( while True: # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. + # This can raise TransportError or TransportConnectionClosed answer_type, execution_result = await listener.get() # If the received answer contains data, @@ -394,52 +359,30 @@ async def connect(self) -> None: - send the init message - wait for the connection acknowledge from the server - create an asyncio task which will be used to receive - and parse the websocket answers + and parse the answers Should be cleaned with a call to the close coroutine """ log.debug("connect: starting") - if self.websocket is None and not self._connecting: + if not self._connected and not self._connecting: # Set connecting to True to avoid a race condition if user is trying # to connect twice using the same client at the same time self._connecting = True - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None - - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } - - # Adding custom parameters passed from init - connect_args.update(self.connect_args) - - # Connection to the specified url # Generate a TimeoutError if taking more than connect_timeout seconds # Set the _connecting flag to False after in all cases try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), + await asyncio.wait_for( + self.adapter.connect(), self.connect_timeout, ) + self._connected = True finally: self._connecting = False - self.websocket = cast(WebSocketClientProtocol, self.websocket) - - self.response_headers = self.websocket.response_headers - # Run the after_connect hook of the subclass await self._after_connect() @@ -452,7 +395,7 @@ async def connect(self) -> None: # if no ACKs are received within the ack_timeout try: await self._initialize() - except ConnectionClosed as e: + except TransportConnectionClosed as e: raise e except ( TransportProtocolError, @@ -531,7 +474,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: try: # We should always have an active websocket connection here - assert self.websocket is not None + assert self._connected # Properly shut down liveness checker if enabled if self.check_keep_alive_task is not None: @@ -560,11 +503,11 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: for query_id, listener in self.listeners.items(): await listener.set_exception(e) - log.debug("_close_coro: close websocket connection") + log.debug("_close_coro: close connection") - await self.websocket.close() + await self.adapter.close() - log.debug("_close_coro: websocket connection closed") + log.debug("_close_coro: connection closed") except Exception as exc: # pragma: no cover log.warning("Exception catched in _close_coro: " + repr(exc)) @@ -573,7 +516,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: start cleanup") - self.websocket = None + self._connected = False self.close_task = None self.check_keep_alive_task = None self._wait_closed.set() @@ -585,12 +528,12 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: if self.close_task is None: - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: + if self._connected: self.close_task = asyncio.shield( asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) ) + else: + log.debug("_fail started with self._connected:False -> already closed") else: log.debug( "close_task is not None in _fail. Previous exception is: " @@ -602,7 +545,7 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: async def close(self) -> None: log.debug("close: starting") - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self._fail(TransportClosed("Transport closed by user")) await self.wait_closed() log.debug("close: done") @@ -610,6 +553,9 @@ async def close(self) -> None: async def wait_closed(self) -> None: log.debug("wait_close: starting") - await self._wait_closed.wait() + try: + await asyncio.wait_for(self._wait_closed.wait(), self.close_timeout) + except asyncio.TimeoutError: + log.debug("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") diff --git a/gql/transport/websockets_common/listener_queue.py b/gql/transport/common/listener_queue.py similarity index 100% rename from gql/transport/websockets_common/listener_queue.py rename to gql/transport/common/listener_queue.py diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 7ec27a33..27cefe2f 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -61,6 +61,13 @@ class TransportClosed(TransportError): """ +class TransportConnectionClosed(TransportError): + """Transport adapter connection closed. + + This exception is by the connection adapter code when a connection closed. + """ + + class TransportAlreadyConnected(TransportError): """Transport is already connected. diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index a7b256eb..382e9014 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -11,7 +11,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_common.base import WebsocketsTransportBase +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -370,7 +370,7 @@ async def _handle_answer( execution_result: Optional[ExecutionResult], ) -> None: if answer_type == "close": - await self.close() + pass else: await super()._handle_answer(answer_type, answer_id, execution_result) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index adebf249..929761e6 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -15,7 +15,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_common.base import WebsocketsTransportBase +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -36,6 +36,7 @@ class WebsocketsTransport(WebsocketsTransportBase): def __init__( self, url: str, + *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, init_payload: Dict[str, Any] = {}, @@ -83,16 +84,24 @@ def __init__( By default: both apollo and graphql-ws subprotocols. """ + if subprotocols is None: + subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + + # Initiliaze WebsocketsTransportBase parent class super().__init__( url, - headers, - ssl, - init_payload, - connect_timeout, - close_timeout, - ack_timeout, - keep_alive_timeout, - connect_args, + headers=headers, + ssl=ssl, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + connect_args=connect_args, + subprotocols=subprotocols, ) self.ping_interval: Optional[Union[int, float]] = ping_interval @@ -115,14 +124,6 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - if subprotocols is None: - self.supported_subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] - else: - self.supported_subprotocols = subprotocols - async def _wait_ack(self) -> None: """Wait for the connection_ack message. Keep alive messages are ignored""" @@ -485,9 +486,8 @@ async def _handle_answer( async def _after_connect(self): # Find the backend subprotocol returned in the response headers - response_headers = self.websocket.response_headers try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] except KeyError: # If the server does not send the subprotocol header, using # the apollo subprotocol by default diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py new file mode 100644 index 00000000..95e54b3f --- /dev/null +++ b/gql/transport/websockets_base.py @@ -0,0 +1,93 @@ +from ssl import SSLContext +from typing import Any, Dict, List, Optional, Union + +from websockets.datastructures import HeadersLike +from websockets.typing import Subprotocol + +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase + + +class WebsocketsTransportBase(SubscriptionTransportBase): + """abstract :ref:`Async Transport ` used to implement + different websockets protocols. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. + """ + + def __init__( + self, + url: str, + *, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + init_payload: Dict[str, Any] = {}, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + connect_args: Dict[str, Any] = {}, + subprotocols: Optional[List[Subprotocol]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param connect_args: Other parameters forwarded to websockets.connect + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + """ + + if subprotocols is not None: + connect_args.update({"subprotocols": subprotocols}) + + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url, + headers=headers, + ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + keep_alive_timeout=keep_alive_timeout, + ) + + self.init_payload: Dict[str, Any] = init_payload + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + @property + def url(self) -> str: + return self.adapter.url + + @property + def headers(self) -> Dict[str, str]: + return self.adapter.headers + + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args diff --git a/gql/transport/websockets_common/__init__.py b/gql/transport/websockets_common/__init__.py deleted file mode 100644 index 7661cf87..00000000 --- a/gql/transport/websockets_common/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .listener_queue import ListenerQueue, ParsedAnswer - -__all__ = ["ListenerQueue", "ParsedAnswer"] diff --git a/tests/conftest.py b/tests/conftest.py index b0103a99..664fe8c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,9 +121,10 @@ async def ssl_aiohttp_server(): "gql.transport.aiohttp", "gql.transport.aiohttp_websockets", "gql.transport.appsync", + "gql.transport.common.base", + "gql.transport.httpx", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", - "gql.transport.httpx", "gql.transport.websockets", "gql.dsl", "gql.utilities.parse_result", diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index befeeb4e..cce31d59 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -6,6 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -233,7 +234,6 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -241,7 +241,7 @@ async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): async with Client(transport=sample_transport): pass @@ -257,13 +257,11 @@ async def test_graphqlws_server_closing_after_ack( event_loop, client_and_graphqlws_server ): - import websockets - session, server = client_and_graphqlws_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 683da43a..1b8f7ccb 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -385,14 +385,12 @@ async def server_countdown_close_connection_in_middle(ws): async def test_graphqlws_subscription_server_connection_closed( event_loop, client_and_graphqlws_server, subscription_str ): - import websockets - session, server = client_and_graphqlws_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): @@ -812,7 +810,6 @@ async def test_graphqlws_subscription_reconnecting_session( event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe ): - import websockets from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import TransportClosed @@ -838,7 +835,7 @@ async def test_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except websockets.exceptions.ConnectionClosedOK: + except TransportConnectionClosed: pass await asyncio.sleep(50 * MS) diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index f39edacb..320d1da3 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -65,6 +65,10 @@ async def test_phoenix_channel_query(event_loop, server, query_str): result = await session.execute(query) print("Client received:", result) + continents = result["continents"] + print("Continents received:", continents) + africa = continents[0] + assert africa["code"] == "AF" @pytest.mark.skip(reason="ssl=False is not working for now") diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index cb9e7274..f9f1f8db 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -141,7 +142,7 @@ async def test_websocket_sending_invalid_data(event_loop, client_and_server, que invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send(invalid_data) + await session.transport.adapter.websocket.send(invalid_data) await asyncio.sleep(2 * MS) @@ -272,7 +273,6 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) async def test_websocket_server_closing_directly(event_loop, server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -280,7 +280,7 @@ async def test_websocket_server_closing_directly(event_loop, server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): async with Client(transport=sample_transport): pass @@ -294,13 +294,11 @@ async def server_closing_after_ack(ws): @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) async def test_websocket_server_closing_after_ack(event_loop, client_and_server): - import websockets - session, server = client_and_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 2c723b3f..f509f676 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -51,19 +51,19 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_starting_client_in_context_manager(event_loop, server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + assert transport.response_headers == {} + assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: - assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol - ) + assert transport._connected is True query1 = gql(query1_str) @@ -85,7 +85,7 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.skip(reason="ssl=False is not working for now") @@ -133,7 +133,7 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -169,7 +169,7 @@ async def test_websocket_using_ssl_connection_self_cert_fail( assert expected_error in str(exc_info.value) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -355,13 +355,13 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -484,7 +484,7 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( await session.execute(query1) - assert transport.websocket is None + assert transport._connected is False @pytest.mark.parametrize("server", [server1_answers], indirect=True) @@ -526,7 +526,7 @@ def test_websocket_execute_sync(server): assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -649,3 +649,52 @@ async def test_websocket_simple_query_with_extensions( execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_adapter_connection_closed(event_loop, server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + await transport.adapter.close() + + with pytest.raises(TransportClosed): + await session.execute(query1) + + # Check client is disconnect here + assert transport._connected is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_transport_closed_in_receive(event_loop, server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport( + url=url, + close_timeout=0.1, + ) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + # await transport.adapter.close() + transport._connected = False + + with pytest.raises(TransportClosed): + await session.execute(query1) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 5af44d59..3efe63a6 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -306,14 +306,12 @@ async def server_countdown_close_connection_in_middle(ws): async def test_websocket_subscription_server_connection_closed( event_loop, client_and_server, subscription_str ): - import websockets - session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py new file mode 100644 index 00000000..f266ce29 --- /dev/null +++ b/tests/test_websockets_adapter.py @@ -0,0 +1,98 @@ +import json + +import pytest +from graphql import print_ast + +from gql import gql +from gql.transport.exceptions import TransportConnectionClosed + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + +server1_answers = [ + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websockets_adapter_simple_query(event_loop, server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str)) + print("query=", query) + + adapter = WebSocketsAdapter(url) + + await adapter.connect() + + init_message = json.dumps({"type": "connection_init", "payload": {}}) + + await adapter.send(init_message) + + result = await adapter.receive() + print(f"result={result}") + + payload = json.dumps({"query": query}) + query_message = json.dumps({"id": 1, "type": "start", "payload": payload}) + + await adapter.send(query_message) + + result = await adapter.receive() + print(f"result={result}") + + await adapter.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websockets_adapter_edge_cases(event_loop, server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str)) + print("query=", query) + + adapter = WebSocketsAdapter(url, headers={"a": 1}, ssl=False, connect_args={}) + + await adapter.connect() + + assert adapter.headers["a"] == 1 + assert adapter.ssl is False + assert adapter.connect_args == {} + assert adapter.response_headers["dummy"] == "test1234" + + # Connect twice causes AssertionError + with pytest.raises(AssertionError): + await adapter.connect() + + await adapter.close() + + # Second close call is ignored + await adapter.close() + + with pytest.raises(TransportConnectionClosed): + await adapter.send("Blah") + + with pytest.raises(TransportConnectionClosed): + await adapter.receive() From 4a8493b22b26110a9c130fb1e732713f94db79d2 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 8 Mar 2025 23:03:20 +0100 Subject: [PATCH 4/9] Using SubscriptionTransportBase instead of WebsocketsTransportBase for Phoenix transport --- gql/transport/phoenix_channel_websockets.py | 40 +++++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 382e9014..0c1bd62b 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,17 +1,18 @@ import asyncio import json import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union from graphql import DocumentNode, ExecutionResult, print_ast -from websockets.exceptions import ConnectionClosed +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import ( + TransportConnectionClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -24,7 +25,7 @@ def __init__(self, query_id: int) -> None: self.unsubscribe_id: Optional[int] = None -class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): +class PhoenixChannelWebsocketsTransport(SubscriptionTransportBase): """The PhoenixChannelWebsocketsTransport is an async transport which allows you to execute queries and subscriptions against an `Absinthe`_ backend using the `Phoenix`_ framework `channels`_. @@ -36,23 +37,48 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): def __init__( self, + url: str, + *, channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, - *args, + ack_timeout: Optional[Union[int, float]] = 10, **kwargs, ) -> None: """Initialize the transport with the given parameters. + :param url: The server URL.'. :param channel_name: Channel on the server this transport will join. The default for Absinthe servers is "__absinthe__:control" :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client + :param ack_timeout: Timeout in seconds to wait for the reply message + from the server. """ self.channel_name: str = channel_name self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscriptions: Dict[str, Subscription] = {} - super().__init__(*args, **kwargs) + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + ws_adapter_args = {} + for ws_arg in ["headers", "ssl", "connect_args"]: + try: + ws_adapter_args[ws_arg] = kwargs.pop(ws_arg) + except KeyError: + pass + + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, + **ws_adapter_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + **kwargs, + ) async def _initialize(self) -> None: """Join the specified channel and wait for the connection ACK. @@ -101,7 +127,7 @@ async def heartbeat_coro(): } ) ) - except ConnectionClosed: # pragma: no cover + except TransportConnectionClosed: # pragma: no cover return self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) From fe6712b383f256eafdb976e1a1d985c375b6236e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 8 Mar 2025 23:21:24 +0100 Subject: [PATCH 5/9] Using SubscriptionTransportBase instead of WebsocketsTransportBase for AppSync transport --- gql/transport/appsync_websockets.py | 41 ++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 0d5139c3..c339e0b8 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -7,8 +7,10 @@ from graphql import DocumentNode, ExecutionResult, print_ast from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import TransportProtocolError, TransportServerError -from .websockets import WebsocketsTransport, WebsocketsTransportBase +from .websockets import WebsocketsTransport log = logging.getLogger("gql.transport.appsync") @@ -19,7 +21,7 @@ pass -class AppSyncWebsocketsTransport(WebsocketsTransportBase): +class AppSyncWebsocketsTransport(SubscriptionTransportBase): """:ref:`Async Transport ` used to execute GraphQL subscription on AWS appsync realtime endpoint. @@ -32,6 +34,7 @@ class AppSyncWebsocketsTransport(WebsocketsTransportBase): def __init__( self, url: str, + *, auth: Optional[AppSyncAuthentication] = None, session: Optional["botocore.session.Session"] = None, ssl: Union[SSLContext, bool] = False, @@ -70,17 +73,25 @@ def __init__( auth = AppSyncIAMAuthentication(host=host, session=session) self.auth = auth + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.init_payload: Dict[str, Any] = {} url = self.auth.get_auth_url(url) - super().__init__( - url, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, connect_timeout=connect_timeout, close_timeout=close_timeout, - ack_timeout=ack_timeout, keep_alive_timeout=keep_alive_timeout, - connect_args=connect_args, ) # Using the same 'graphql-ws' protocol as the apollo protocol @@ -181,7 +192,7 @@ async def _send_query( return query_id - subscribe = WebsocketsTransportBase.subscribe # type: ignore[assignment] + subscribe = SubscriptionTransportBase.subscribe # type: ignore[assignment] """Send a subscription query and receive the results using a python async generator. @@ -212,3 +223,19 @@ async def execute( WebsocketsTransport._send_init_message_and_wait_ack ) _wait_ack = WebsocketsTransport._wait_ack + + @property + def url(self) -> str: + return self.adapter.url + + @property + def headers(self) -> Dict[str, str]: + return self.adapter.headers + + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args From 352af37ec8c9fa2eb89540155243bedf6d16d887 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 9 Mar 2025 17:02:13 +0100 Subject: [PATCH 6/9] Put dependency-free websockets protocol in websockets_protocol.py --- gql/transport/appsync_websockets.py | 12 - gql/transport/common/adapters/connection.py | 5 +- gql/transport/common/adapters/websockets.py | 6 +- gql/transport/common/base.py | 8 + gql/transport/websockets.py | 474 +----------------- gql/transport/websockets_base.py | 93 ---- gql/transport/websockets_protocol.py | 516 ++++++++++++++++++++ tests/test_phoenix_channel_subscription.py | 4 +- tests/test_websocket_query.py | 3 + 9 files changed, 564 insertions(+), 557 deletions(-) delete mode 100644 gql/transport/websockets_base.py create mode 100644 gql/transport/websockets_protocol.py diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index c339e0b8..e0f5c031 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -224,18 +224,6 @@ async def execute( ) _wait_ack = WebsocketsTransport._wait_ack - @property - def url(self) -> str: - return self.adapter.url - - @property - def headers(self) -> Dict[str, str]: - return self.adapter.headers - @property def ssl(self) -> Union[SSLContext, bool]: return self.adapter.ssl - - @property - def connect_args(self) -> Dict[str, Any]: - return self.adapter.connect_args diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py index fbe38e3b..cf361b8d 100644 --- a/gql/transport/common/adapters/connection.py +++ b/gql/transport/common/adapters/connection.py @@ -1,5 +1,5 @@ import abc -from typing import Dict +from typing import Any, Dict class AdapterConnection(abc.ABC): @@ -8,6 +8,9 @@ class AdapterConnection(abc.ABC): This allows different WebSocket implementations to be used interchangeably. """ + url: str + connect_args: Dict[str, Any] + @abc.abstractmethod async def connect(self) -> None: """Connect to the server.""" diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index 95fbaf39..4494e256 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -19,7 +19,7 @@ def __init__( *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, - connect_args: Dict[str, Any] = {}, + connect_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -31,6 +31,10 @@ def __init__( self.url: str = url self._headers: Optional[HeadersLike] = headers self.ssl: Union[SSLContext, bool] = ssl + + if connect_args is None: + connect_args = {} + self.connect_args = connect_args self.websocket: Optional[WebSocketClientProtocol] = None diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 9ee07dd8..40d0b4cb 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -559,3 +559,11 @@ async def wait_closed(self) -> None: log.debug("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") + + @property + def url(self) -> str: + return self.adapter.url + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 929761e6..7a0ce10a 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,26 +1,13 @@ -import asyncio -import json -import logging -from contextlib import suppress from ssl import SSLContext -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Union -from graphql import DocumentNode, ExecutionResult, print_ast from websockets.datastructures import HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Subprotocol -from .exceptions import ( - TransportProtocolError, - TransportQueryError, - TransportServerError, -) -from .websockets_base import WebsocketsTransportBase +from .common.adapters.websockets import WebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger(__name__) - -class WebsocketsTransport(WebsocketsTransportBase): +class WebsocketsTransport(WebsocketsProtocolTransportBase): """:ref:`Async Transport ` used to execute GraphQL queries on remote servers with websocket connection. @@ -28,18 +15,13 @@ class WebsocketsTransport(WebsocketsTransportBase): on a websocket connection. """ - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") - GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") - def __init__( self, url: str, *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + init_payload: Optional[Dict[str, Any]] = None, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, @@ -47,8 +29,8 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, - connect_args: Dict[str, Any] = {}, - subprotocols: Optional[List[Subprotocol]] = None, + connect_args: Optional[Dict[str, Any]] = None, + subprotocols: Optional[List[str]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -84,437 +66,33 @@ def __init__( By default: both apollo and graphql-ws subprotocols. """ - if subprotocols is None: - subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] - - # Initiliaze WebsocketsTransportBase parent class - super().__init__( - url, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, headers=headers, ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, init_payload=init_payload, connect_timeout=connect_timeout, close_timeout=close_timeout, ack_timeout=ack_timeout, keep_alive_timeout=keep_alive_timeout, - connect_args=connect_args, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, subprotocols=subprotocols, ) - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, answer_id, execution_result = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. - """ - - init_message = json.dumps( - {"type": "connection_init", "payload": self.init_payload} - ) - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - await self._send_init_message_and_wait_ack() - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(json.dumps(ping_message)) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(json.dumps(pong_message)) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = json.dumps({"id": str(query_id), "type": "stop"}) - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. - """ - - complete_message = json.dumps({"id": str(query_id), "type": "complete"}) - - await self._send(complete_message) - - async def _stop_listener(self, query_id: int): - """Stop the listener corresponding to the query_id depending on the - detected backend protocol. - - For apollo: send a "stop" message - (a "complete" message will be sent from the backend) - - For graphql-ws: send a "complete" message and simulate the reception - of a "complete" message from the backend - """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. - """ - - connection_terminate_message = json.dumps({"type": "connection_terminate"}) - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. - """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query_str = json.dumps( - {"id": str(query_id), "type": query_type, "payload": payload} - ) - - await self._send(query_str) - - return query_id - - async def _connection_terminate(self): - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - - def _parse_answer_graphqlws( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = json_answer.get("payload") - - if answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = json_answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = json_answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = json_answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. - - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - # Put the answer in the queue - await super()._handle_answer(answer_type, answer_id, execution_result) - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - self.pong_received.set() - - async def _after_connect(self): - - # Find the backend subprotocol returned in the response headers - try: - self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def _after_initialize(self): - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - log.debug("_close_hook: start") - - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - log.debug("_close_hook: cancelling send_ping_task") - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError, ConnectionClosed): - log.debug("_close_hook: awaiting send_ping_task") - await self.send_ping_task - self.send_ping_task = None + @property + def headers(self) -> Optional[HeadersLike]: + return self.adapter.headers - log.debug("_close_hook: end") + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py deleted file mode 100644 index 95e54b3f..00000000 --- a/gql/transport/websockets_base.py +++ /dev/null @@ -1,93 +0,0 @@ -from ssl import SSLContext -from typing import Any, Dict, List, Optional, Union - -from websockets.datastructures import HeadersLike -from websockets.typing import Subprotocol - -from .common.adapters.websockets import WebSocketsAdapter -from .common.base import SubscriptionTransportBase - - -class WebsocketsTransportBase(SubscriptionTransportBase): - """abstract :ref:`Async Transport ` used to implement - different websockets protocols. - - This transport uses asyncio and the websockets library in order to send requests - on a websocket connection. - """ - - def __init__( - self, - url: str, - *, - headers: Optional[HeadersLike] = None, - ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, - connect_timeout: Optional[Union[int, float]] = 10, - close_timeout: Optional[Union[int, float]] = 10, - ack_timeout: Optional[Union[int, float]] = 10, - keep_alive_timeout: Optional[Union[int, float]] = None, - connect_args: Dict[str, Any] = {}, - subprotocols: Optional[List[Subprotocol]] = None, - ) -> None: - """Initialize the transport with the given parameters. - - :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. - :param headers: Dict of HTTP Headers. - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param init_payload: Dict of the payload sent in the connection_init message. - :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. If None is provided this will wait forever. - :param close_timeout: Timeout in seconds for the close. If None is provided - this will wait forever. - :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. If None is provided this will wait forever. - :param keep_alive_timeout: Optional Timeout in seconds to receive - a sign of liveness from the server. - :param connect_args: Other parameters forwarded to websockets.connect - :param subprotocols: list of subprotocols sent to the - backend in the 'subprotocols' http header. - """ - - if subprotocols is not None: - connect_args.update({"subprotocols": subprotocols}) - - # Instanciate a WebSocketAdapter to indicate the use - # of the websockets dependency for this transport - self.adapter: WebSocketsAdapter = WebSocketsAdapter( - url, - headers=headers, - ssl=ssl, - connect_args=connect_args, - ) - - # Initialize the generic SubscriptionTransportBase parent class - super().__init__( - adapter=self.adapter, - connect_timeout=connect_timeout, - close_timeout=close_timeout, - keep_alive_timeout=keep_alive_timeout, - ) - - self.init_payload: Dict[str, Any] = init_payload - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" - - @property - def url(self) -> str: - return self.adapter.url - - @property - def headers(self) -> Dict[str, str]: - return self.adapter.headers - - @property - def ssl(self) -> Union[SSLContext, bool]: - return self.adapter.ssl - - @property - def connect_args(self) -> Dict[str, Any]: - return self.adapter.connect_args diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py new file mode 100644 index 00000000..84ba7656 --- /dev/null +++ b/gql/transport/websockets_protocol.py @@ -0,0 +1,516 @@ +import asyncio +import json +import logging +from contextlib import suppress +from typing import Any, Dict, List, Optional, Tuple, Union + +from graphql import DocumentNode, ExecutionResult, print_ast + +from .common.adapters.websockets import AdapterConnection +from .common.base import SubscriptionTransportBase +from .exceptions import ( + TransportConnectionClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +log = logging.getLogger("gql.transport.websockets") + + +class WebsocketsProtocolTransportBase(SubscriptionTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. + + This transport uses asyncio and the provided websockets adapter library + in order to send requests on a websocket connection. + """ + + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL = "graphql-ws" + GRAPHQLWS_SUBPROTOCOL = "graphql-transport-ws" + + def __init__( + self, + *, + adapter: AdapterConnection, + init_payload: Optional[Dict[str, Any]] = None, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, + subprotocols: Optional[List[str]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param adapter: The connection dependency adapter + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. + """ + + if subprotocols is None: + subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + + self.adapter.connect_args.update({"subprotocols": subprotocols}) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + keep_alive_timeout=keep_alive_timeout, + ) + + if init_payload is None: + init_payload = {} + + self.init_payload: Dict[str, Any] = init_payload + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout + + self.send_ping_task: Optional[asyncio.Future] = None + + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + init_message = json.dumps( + {"type": "connection_init", "payload": self.init_payload} + ) + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + + async def _initialize(self): + await self._send_init_message_and_wait_ack() + + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol""" + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(json.dumps(ping_message)) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol""" + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(json.dumps(pong_message)) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. + """ + + stop_message = json.dumps({"id": str(query_id), "type": "stop"}) + + await self._send(stop_message) + + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = json.dumps({"id": str(query_id), "type": "complete"}) + + await self._send(complete_message) + + async def _stop_listener(self, query_id: int): + """Stop the listener corresponding to the query_id depending on the + detected backend protocol. + + For apollo: send a "stop" message + (a "complete" message will be sent from the backend) + + For graphql-ws: send a "complete" message and simulate the reception + of a "complete" message from the backend + """ + log.debug(f"stop listener {query_id}") + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + + async def _send_connection_terminate_message(self) -> None: + """Send a connection_terminate message to the provided websocket connection. + + This message indicates that the connection will disconnect. + """ + + connection_terminate_message = json.dumps({"type": "connection_terminate"}) + + await self._send(connection_terminate_message) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name + + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + + query_str = json.dumps( + {"id": str(query_id), "type": query_type, "payload": payload} + ) + + await self._send(query_str) + + return query_id + + async def _connection_terminate(self): + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + + def _parse_answer_graphqlws( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = json_answer.get("payload") + + if answer_type == "next": + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + if not isinstance(payload, list): + raise ValueError("payload is not a list") + + raise TransportQueryError( + str(payload[0]), query_id=answer_id, errors=payload + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "data" or answer_type == "error": + + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) + + return self._parse_answer_apollo(json_answer) + + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. + """ + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + # Put the answer in the queue + await super()._handle_answer(answer_type, answer_id, execution_result) + + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + + async def _after_connect(self): + + # Find the backend subprotocol returned in the response headers + try: + self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + + async def _after_initialize(self): + + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + + async def _close_hook(self): + log.debug("_close_hook: start") + + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + log.debug("_close_hook: cancelling send_ping_task") + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError, TransportConnectionClosed): + log.debug("_close_hook: awaiting send_ping_task") + await self.send_ping_task + self.send_ping_task = None + + log.debug("_close_hook: end") diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 6193c658..3be4b07d 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -186,7 +186,7 @@ async def test_phoenix_channel_subscription( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger websockets_logger.setLevel(logging.DEBUG) phoenix_logger.setLevel(logging.DEBUG) @@ -227,7 +227,7 @@ async def test_phoenix_channel_subscription_no_break( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger from .conftest import MS diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index f509f676..7aa853bf 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -157,6 +157,9 @@ async def test_websocket_using_ssl_connection_self_cert_fail( transport = WebsocketsTransport(url=url, **extra_args) + if verify_https == "explicitely_enabled": + assert transport.ssl is True + with pytest.raises(SSLCertVerificationError) as exc_info: async with Client(transport=transport) as session: From 496add12c35fcc59bb580d15770b8ae8c633179d Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 9 Mar 2025 23:55:35 +0100 Subject: [PATCH 7/9] Use new connection adapter for aiohttp websockets --- gql/transport/aiohttp.py | 57 +- gql/transport/aiohttp_websockets.py | 1067 +---------------- gql/transport/appsync_websockets.py | 2 +- gql/transport/common/adapters/aiohttp.py | 269 +++++ gql/transport/common/adapters/connection.py | 13 +- gql/transport/common/adapters/websockets.py | 38 +- gql/transport/common/aiohttp_closed_event.py | 59 + gql/transport/websockets_protocol.py | 4 +- tests/test_aiohttp_websocket_exceptions.py | 8 +- ..._aiohttp_websocket_graphqlws_exceptions.py | 5 +- ...iohttp_websocket_graphqlws_subscription.py | 6 +- tests/test_aiohttp_websocket_query.py | 39 +- tests/test_aiohttp_websocket_subscription.py | 16 +- tests/test_phoenix_channel_query.py | 22 +- tests/test_websocket_query.py | 22 +- 15 files changed, 481 insertions(+), 1146 deletions(-) create mode 100644 gql/transport/common/adapters/aiohttp.py create mode 100644 gql/transport/common/aiohttp_closed_event.py diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 0c332205..c1302794 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,5 +1,4 @@ import asyncio -import functools import io import json import logging @@ -28,6 +27,7 @@ from ..utils import extract_files from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport +from .common.aiohttp_closed_event import create_aiohttp_closed_event from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -147,59 +147,6 @@ async def connect(self) -> None: else: raise TransportAlreadyConnected("Transport is already connected") - @staticmethod - def create_aiohttp_closed_event(session) -> asyncio.Event: - """Work around aiohttp issue that doesn't properly close transports on exit. - - See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 - - Returns: - An event that will be set once all transports have been properly closed. - """ - - ssl_transports = 0 - all_is_lost = asyncio.Event() - - def connection_lost(exc, orig_lost): - nonlocal ssl_transports - - try: - orig_lost(exc) - finally: - ssl_transports -= 1 - if ssl_transports == 0: - all_is_lost.set() - - def eof_received(orig_eof_received): - try: # pragma: no cover - orig_eof_received() - except AttributeError: # pragma: no cover - # It may happen that eof_received() is called after - # _app_protocol and _transport are set to None. - pass - - for conn in session.connector._conns.values(): - for handler, _ in conn: - proto = getattr(handler.transport, "_ssl_protocol", None) - if proto is None: - continue - - ssl_transports += 1 - orig_lost = proto.connection_lost - orig_eof_received = proto.eof_received - - proto.connection_lost = functools.partial( - connection_lost, orig_lost=orig_lost - ) - proto.eof_received = functools.partial( - eof_received, orig_eof_received=orig_eof_received - ) - - if ssl_transports == 0: - all_is_lost.set() - - return all_is_lost - async def close(self) -> None: """Coroutine which will close the aiohttp session. @@ -219,7 +166,7 @@ async def close(self) -> None: log.debug("connector_owner is False -> not closing connector") else: - closed_event = self.create_aiohttp_closed_event(self.session) + closed_event = create_aiohttp_closed_event(self.session) await self.session.close() try: await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index f97fbba8..59d870f6 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,53 +1,26 @@ -import asyncio -import json -import logging -import warnings -from contextlib import suppress from ssl import SSLContext -from typing import ( - Any, - AsyncGenerator, - Collection, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Union, -) +from typing import Any, Dict, List, Literal, Mapping, Optional, Union -import aiohttp -from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp import BasicAuth, ClientSession, Fingerprint from aiohttp.typedefs import LooseHeaders, StrOrURL -from graphql import DocumentNode, ExecutionResult, print_ast -from multidict import CIMultiDictProxy -from .aiohttp import AIOHTTPTransport -from .async_transport import AsyncTransport -from .common import ListenerQueue -from .exceptions import ( - TransportAlreadyConnected, - TransportClosed, - TransportProtocolError, - TransportQueryError, - TransportServerError, -) +from .common.adapters.aiohttp import AIOHTTPWebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger("gql.transport.aiohttp_websockets") +class AIOHTTPWebsocketsTransport(WebsocketsProtocolTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. -class AIOHTTPWebsocketsTransport(AsyncTransport): - - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL: str = "graphql-ws" - GRAPHQLWS_SUBPROTOCOL: str = "graphql-transport-ws" + This transport uses asyncio and the provided aiohttp adapter library + in order to send requests on a websocket connection. + """ def __init__( self, url: StrOrURL, *, - subprotocols: Optional[Collection[str]] = None, + subprotocols: Optional[List[str]] = None, heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, @@ -68,8 +41,9 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, + session: Optional[ClientSession] = None, client_session_args: Optional[Dict[str, Any]] = None, - connect_args: Dict[str, Any] = {}, + connect_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -140,6 +114,7 @@ def __init__( :param answer_pings: Whether the client answers the pings from the backend (for the graphql-ws protocol). By default: True + :param session: Optional aiohttp.ClientSession instance. :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ :param connect_args: Dict of extra args passed to @@ -150,986 +125,46 @@ def __init__( .. _aiohttp.ClientSession: https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession """ - self.url: StrOrURL = url - self.heartbeat: Optional[float] = heartbeat - self.auth: Optional[BasicAuth] = auth - self.origin: Optional[str] = origin - self.params: Optional[Mapping[str, str]] = params - self.headers: Optional[LooseHeaders] = headers - - self.proxy: Optional[StrOrURL] = proxy - self.proxy_auth: Optional[BasicAuth] = proxy_auth - self.proxy_headers: Optional[LooseHeaders] = proxy_headers - - self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl - - self.websocket_close_timeout: float = websocket_close_timeout - self.receive_timeout: Optional[float] = receive_timeout - - self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - - self.init_payload: Dict[str, Any] = init_payload - - # We need to set an event loop here if there is none - # Or else we will not be able to create an asyncio.Event() - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - self._loop = asyncio.get_event_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - self._next_keep_alive_message.set() - - self.session: Optional[aiohttp.ClientSession] = None - self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None - self.next_query_id: int = 1 - self.listeners: Dict[int, ListenerQueue] = {} - self._connecting: bool = False - self.response_headers: Optional[CIMultiDictProxy[str]] = None - - self.receive_data_task: Optional[asyncio.Future] = None - self.check_keep_alive_task: Optional[asyncio.Future] = None - self.close_task: Optional[asyncio.Future] = None - - self._wait_closed: asyncio.Event = asyncio.Event() - self._wait_closed.set() - - self._no_more_listeners: asyncio.Event = asyncio.Event() - self._no_more_listeners.set() - - self.payloads: Dict[str, Any] = {} - - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - self.supported_subprotocols: Collection[str] = subprotocols or ( - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ) - - self.close_exception: Optional[Exception] = None - - self.client_session_args = client_session_args - self.connect_args = connect_args - - def _parse_answer_graphqlws( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = answer.get("payload") - - if answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, _, _ = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. - """ - - init_message = {"type": "connection_init", "payload": self.init_payload} - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - """Hook to send the initialization messages after the connection - and potentially wait for the backend ack. - """ - await self._send_init_message_and_wait_ack() - - async def _stop_listener(self, query_id: int): - """Hook to stop to listen to a specific query. - Will send a stop message in some subclasses. - """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _after_connect(self): - """Hook to add custom code for subclasses after the connection - has been established. - """ - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket._response.headers - log.debug(f"Response headers: {response_headers!r}") - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(ping_message) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(pong_message) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = {"id": str(query_id), "type": "stop"} - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. - """ - - complete_message = {"id": str(query_id), "type": "complete"} - - await self._send(complete_message) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. - - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _after_initialize(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - """Hook to add custom code for subclasses for the connection close""" - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): - await self.send_ping_task - self.send_ping_task = None - - async def _connection_terminate(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. - """ - - connection_terminate_message = {"type": "connection_terminate"} - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. - """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query = {"id": str(query_id), "type": query_type, "payload": payload} - - await self._send(query) - - return query_id - - async def _send(self, message: Dict[str, Any]) -> None: - """Send the provided message to the websocket connection and log the message""" - - if self.websocket is None: - raise TransportClosed("WebSocket connection is closed") - - try: - await self.websocket.send_json(message) - log.info(">>> %s", message) - except ConnectionResetError as e: - await self._fail(e, clean_close=False) - raise e - - async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - while True: - ws_message = await self.websocket.receive() - - # Ignore low-level ping and pong received - if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): - break - - if ws_message.type in ( - WSMsgType.CLOSE, - WSMsgType.CLOSED, - WSMsgType.CLOSING, - WSMsgType.ERROR, - ): - raise ConnectionResetError - elif ws_message.type is WSMsgType.BINARY: - raise TransportProtocolError("Binary data received in the websocket") - - assert ws_message.type is WSMsgType.TEXT - - answer: str = ws_message.data - - log.info("<<< %s", answer) - - return answer - - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. - """ - if query_id in self.listeners: - del self.listeners[query_id] - - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") - - if remaining == 0: - self._no_more_listeners.set() - - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ - - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) - - # Reset for the next iteration - self._next_keep_alive_message.clear() - - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) - - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, - ) - - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. - pass - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - self.pong_received.set() - - async def _receive_data_loop(self) -> None: - """Main asyncio task which will listen to the incoming messages and will - call the parse_answer and handle_answer methods of the subclass.""" - log.debug("Entering _receive_data_loop()") - - try: - while True: - - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionResetError, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed as e: - await self._fail(e, clean_close=False) - raise e - - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass - - continue - - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. - await self._fail(e, clean_close=False) - break - - await self._handle_answer(answer_type, answer_id, execution_result) - - finally: - log.debug("Exiting _receive_data_loop()") - - async def connect(self) -> None: - log.debug("connect: starting") - - if self.session is None: - client_session_args: Dict[str, Any] = {} - - # Adding custom parameters passed from init - if self.client_session_args: - client_session_args.update(self.client_session_args) # type: ignore - - self.session = aiohttp.ClientSession(**client_session_args) - - if self.websocket is None and not self._connecting: - self._connecting = True - - connect_args: Dict[str, Any] = { - "url": self.url, - "headers": self.headers, - "auth": self.auth, - "heartbeat": self.heartbeat, - "origin": self.origin, - "params": self.params, - "protocols": self.supported_subprotocols, - "proxy": self.proxy, - "proxy_auth": self.proxy_auth, - "proxy_headers": self.proxy_headers, - "timeout": self.websocket_close_timeout, - "receive_timeout": self.receive_timeout, - } - - if self.ssl is not None: - connect_args.update( - { - "ssl": self.ssl, - } - ) - - # Adding custom parameters passed from init - if self.connect_args: - connect_args.update(self.connect_args) - - try: - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - self.websocket = await asyncio.wait_for( - self.session.ws_connect( - **connect_args, - ), - self.connect_timeout, - ) - finally: - self._connecting = False - - self.response_headers = self.websocket._response.headers - - await self._after_connect() - - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() - - # Send the init message and wait for the ack from the server - # Note: This should generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._initialize() - except ConnectionResetError as e: - raise e - except ( - TransportProtocolError, - TransportServerError, - asyncio.TimeoutError, - ) as e: - await self._fail(e, clean_close=False) - raise e - - # Run the after_init hook of the subclass - await self._after_initialize() - - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - - else: - raise TransportAlreadyConnected("Transport is already connected") - - log.debug("connect: done") - async def _clean_close(self) -> None: - """Coroutine which will: - - - send stop messages for each active subscription to the server - - send the connection terminate message - """ - log.debug(f"Listeners: {self.listeners}") - - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): - print(f"Listener {query_id} send_stop: {listener.send_stop}") - - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") - - # Calling the subclass hook - await self._connection_terminate() - - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: - - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ - - log.debug("_close_coro: starting") - - try: - - try: - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel keep alive task exception: " + repr(exc) - ) - - try: - # Calling the subclass close hook - await self._close_hook() - except Exception as exc: # pragma: no cover - log.warning("_close_coro close_hook exception: " + repr(exc)) - - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e - - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close() - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) - - log.debug("_close_coro: sending exception to listeners") - - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) - - log.debug("_close_coro: close websocket connection") - - try: - assert self.websocket is not None - - await self.websocket.close() - self.websocket = None - except Exception as exc: - log.warning("_close_coro websocket close exception: " + repr(exc)) - - log.debug("_close_coro: close aiohttp session") - - if ( - self.client_session_args - and self.client_session_args.get("connector_owner") is False - ): - - log.debug("connector_owner is False -> not closing connector") - - else: - try: - assert self.session is not None - - closed_event = AIOHTTPTransport.create_aiohttp_closed_event( - self.session - ) - await self.session.close() - try: - await asyncio.wait_for( - closed_event.wait(), self.ssl_close_timeout - ) - except asyncio.TimeoutError: - pass - except Exception as exc: # pragma: no cover - log.warning("_close_coro session close exception: " + repr(exc)) - - self.session = None - - log.debug("_close_coro: aiohttp session closed") - - try: - assert self.receive_data_task is not None - - self.receive_data_task.cancel() - with suppress(asyncio.CancelledError): - await self.receive_data_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel receive data task exception: " + repr(exc) - ) - - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) - - finally: - - log.debug("_close_coro: final cleanup") - - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None - self.receive_data_task = None - self._wait_closed.set() - - log.debug("_close_coro: exiting") - - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) - - if self.close_task is None: - - if self._wait_closed.is_set(): - log.debug("_fail started but transport is already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) - - async def close(self) -> None: - log.debug("close: starting") - - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() - - log.debug("close: done") - - async def wait_closed(self) -> None: - log.debug("wait_close: starting") - - if not self._wait_closed.is_set(): - await self._wait_closed.wait() - - log.debug("wait_close: done") - - async def execute( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - - Send a query but close the async generator as soon as we have the first answer. - - The result is sent as an ExecutionResult object. - """ - first_result = None - - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False + # Instanciate a AIOHTTPWebSocketAdapter to indicate the use + # of the aiohttp dependency for this transport + self.adapter: AIOHTTPWebSocketsAdapter = AIOHTTPWebSocketsAdapter( + url=url, + headers=headers, + ssl=ssl, + session=session, + client_session_args=client_session_args, + connect_args=connect_args, + heartbeat=heartbeat, + auth=auth, + origin=origin, + params=params, + proxy=proxy, + proxy_auth=proxy_auth, + proxy_headers=proxy_headers, + websocket_close_timeout=websocket_close_timeout, + receive_timeout=receive_timeout, + ssl_close_timeout=ssl_close_timeout, ) - async for result in generator: - first_result = result - break - - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) - - return first_result - - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. - - The query can be a graphql query, mutation or subscription. - - The results are sent as an ExecutionResult object. - """ - - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, + subprotocols=subprotocols, ) - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener - - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() - - try: - # Loop over the received answers - while True: - - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. - answer_type, execution_result = await listener.get() - - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result - - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break - - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False + @property + def headers(self) -> Optional[LooseHeaders]: + return self.adapter.headers - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) + @property + def ssl(self) -> Optional[Union[SSLContext, Literal[False], Fingerprint]]: + return self.adapter.ssl diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index e0f5c031..f35cefe5 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -95,7 +95,7 @@ def __init__( ) # Using the same 'graphql-ws' protocol as the apollo protocol - self.supported_subprotocols = [ + self.adapter.subprotocols = [ WebsocketsTransport.APOLLO_SUBPROTOCOL, ] self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py new file mode 100644 index 00000000..d9af7c50 --- /dev/null +++ b/gql/transport/common/adapters/aiohttp.py @@ -0,0 +1,269 @@ +import asyncio +import logging +from ssl import SSLContext +from typing import Any, Dict, Literal, Mapping, Optional, Union + +import aiohttp +from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp.typedefs import LooseHeaders, StrOrURL +from multidict import CIMultiDictProxy + +from ...exceptions import TransportConnectionClosed, TransportProtocolError +from ..aiohttp_closed_event import create_aiohttp_closed_event +from .connection import AdapterConnection + +log = logging.getLogger("gql.transport.common.adapters.aiohttp") + + +class AIOHTTPWebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the aiohttp library.""" + + def __init__( + self, + url: StrOrURL, + *, + headers: Optional[LooseHeaders] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + session: Optional[aiohttp.ClientSession] = None, + client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Optional[Dict[str, Any]] = None, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, + websocket_close_timeout: float = 10.0, + receive_timeout: Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: SSL validation mode. ``True`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + :param session: Optional aiohttp opened session. + :param client_session_args: Dict of extra args passed to + `aiohttp.ClientSession`_ + :param connect_args: Dict of extra args passed to + `aiohttp.ClientSession.ws_connect`_ + + :param float heartbeat: Send low level `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + :param auth: An object that represents HTTP Basic Authorization. + :class:`~aiohttp.BasicAuth` (optional) + :param str origin: Origin header to send to server(optional) + :param params: Mapping, iterable of tuple of *key*/*value* pairs or + string to be sent as parameters in the query + string of the new request. Ignored for subsequent + redirected requests (optional) + + Allowed values are: + + - :class:`collections.abc.Mapping` e.g. :class:`dict`, + :class:`multidict.MultiDict` or + :class:`multidict.MultiDictProxy` + - :class:`collections.abc.Iterable` e.g. :class:`tuple` or + :class:`list` + - :class:`str` with preferably url-encoded content + (**Warning:** content will not be encoded by *aiohttp*) + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) + :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP + Basic Authorization (optional) + :param float websocket_close_timeout: Timeout for websocket to close. + ``10`` seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. ``None`` (unlimited) + seconds by default + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly + """ + super().__init__( + url=str(url), + connect_args=connect_args, + ) + + self._headers: Optional[LooseHeaders] = headers + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl + + self.session: Optional[aiohttp.ClientSession] = session + self._using_external_session = True if self.session else False + + if client_session_args is None: + client_session_args = {} + self.client_session_args = client_session_args + + self.heartbeat: Optional[float] = heartbeat + self.auth: Optional[BasicAuth] = auth + self.origin: Optional[str] = origin + self.params: Optional[Mapping[str, str]] = params + + self.proxy: Optional[StrOrURL] = proxy + self.proxy_auth: Optional[BasicAuth] = proxy_auth + self.proxy_headers: Optional[LooseHeaders] = proxy_headers + + self.websocket_close_timeout: float = websocket_close_timeout + self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + + self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None + self._response_headers: Optional[CIMultiDictProxy[str]] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + # Create a session if necessary + if self.session is None: + client_session_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + client_session_args.update(self.client_session_args) # type: ignore + + self.session = aiohttp.ClientSession(**client_session_args) + + connect_args: Dict[str, Any] = { + "url": self.url, + "headers": self.headers, + "auth": self.auth, + "heartbeat": self.heartbeat, + "origin": self.origin, + "params": self.params, + "proxy": self.proxy, + "proxy_auth": self.proxy_auth, + "proxy_headers": self.proxy_headers, + "timeout": self.websocket_close_timeout, + "receive_timeout": self.receive_timeout, + } + + if self.subprotocols: + connect_args["protocols"] = self.subprotocols + + if self.ssl is not None: + connect_args["ssl"] = self.ssl + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + try: + self.websocket = await self.session.ws_connect( + **connect_args, + ) + except Exception as e: + raise TransportConnectionClosed("Connect failed") from e + + self._response_headers = self.websocket._response.headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionClosed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + try: + await self.websocket.send_str(message) + except ConnectionResetError as e: + raise TransportConnectionClosed("Connection was closed") from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. + + Returns: + String message received + + Raises: + TransportConnectionClosed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + while True: + ws_message = await self.websocket.receive() + + # Ignore low-level ping and pong received + if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): + break + + if ws_message.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + WSMsgType.ERROR, + ): + raise TransportConnectionClosed("Connection was closed") + elif ws_message.type is WSMsgType.BINARY: + raise TransportProtocolError("Binary data received in the websocket") + + assert ws_message.type is WSMsgType.TEXT + + answer: str = ws_message.data + + return answer + + async def _close_session(self) -> None: + """Close the aiohttp session.""" + + assert self.session is not None + + closed_event = create_aiohttp_closed_event(self.session) + await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass + finally: + self.session = None + + async def close(self) -> None: + """Close the WebSocket connection.""" + + if self.websocket: + websocket = self.websocket + self.websocket = None + try: + await websocket.close() + except Exception as exc: # pragma: no cover + log.warning("websocket.close() exception: " + repr(exc)) + + if self.session and not self._using_external_session: + await self._close_session() + + @property + def headers(self) -> Optional[LooseHeaders]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._headers: + return self._headers + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers) + return {} diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py index cf361b8d..f3d77421 100644 --- a/gql/transport/common/adapters/connection.py +++ b/gql/transport/common/adapters/connection.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Dict +from typing import Any, Dict, List, Optional class AdapterConnection(abc.ABC): @@ -10,6 +10,17 @@ class AdapterConnection(abc.ABC): url: str connect_args: Dict[str, Any] + subprotocols: Optional[List[str]] + + def __init__(self, url: str, connect_args: Optional[Dict[str, Any]]): + """Initialize the connection adapter.""" + self.url: str = url + + if connect_args is None: + connect_args = {} + self.connect_args = connect_args + + self.subprotocols = None @abc.abstractmethod async def connect(self) -> None: diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index 4494e256..383d4def 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -1,14 +1,16 @@ +import logging from ssl import SSLContext from typing import Any, Dict, Optional, Union import websockets from websockets.client import WebSocketClientProtocol from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import WebSocketException from ...exceptions import TransportConnectionClosed, TransportProtocolError from .connection import AdapterConnection +log = logging.getLogger("gql.transport.common.adapters.websockets") + class WebSocketsAdapter(AdapterConnection): """AdapterConnection implementation using the websockets library.""" @@ -26,16 +28,17 @@ def __init__( :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. :param headers: Dict of HTTP Headers. :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param connect_args: Other parameters forwarded to websockets.connect + :param connect_args: Other parameters forwarded to + `websockets.connect `_ """ - self.url: str = url - self._headers: Optional[HeadersLike] = headers - self.ssl: Union[SSLContext, bool] = ssl + super().__init__( + url=url, + connect_args=connect_args, + ) - if connect_args is None: - connect_args = {} - - self.connect_args = connect_args + self._headers: Optional[HeadersLike] = headers + self.ssl = ssl self.websocket: Optional[WebSocketClientProtocol] = None self._response_headers: Optional[Headers] = None @@ -57,14 +60,17 @@ async def connect(self) -> None: "extra_headers": self.headers, } + if self.subprotocols: + connect_args["subprotocols"] = self.subprotocols + # Adding custom parameters passed from init connect_args.update(self.connect_args) # Connection to the specified url try: self.websocket = await websockets.client.connect(self.url, **connect_args) - except WebSocketException as e: - raise TransportConnectionClosed("Connection was closed") from e + except Exception as e: + raise TransportConnectionClosed("Connect failed") from e self._response_headers = self.websocket.response_headers @@ -82,7 +88,7 @@ async def send(self, message: str) -> None: try: await self.websocket.send(message) - except WebSocketException as e: + except Exception as e: raise TransportConnectionClosed("Connection was closed") from e async def receive(self) -> str: @@ -102,9 +108,7 @@ async def receive(self) -> str: # Wait for the next websocket frame. Can raise ConnectionClosed try: data = await self.websocket.recv() - except WebSocketException as e: - # When the connection is closed, make sure to clean up resources - self.websocket = None + except Exception as e: raise TransportConnectionClosed("Connection was closed") from e # websocket.recv() can return either str or bytes @@ -124,14 +128,14 @@ async def close(self) -> None: await websocket.close() @property - def headers(self) -> Dict[str, str]: + def headers(self) -> Optional[HeadersLike]: """Get the response headers from the WebSocket connection. Returns: Dictionary of response headers """ if self._headers: - return dict(self._headers) + return self._headers return {} @property diff --git a/gql/transport/common/aiohttp_closed_event.py b/gql/transport/common/aiohttp_closed_event.py new file mode 100644 index 00000000..412448f9 --- /dev/null +++ b/gql/transport/common/aiohttp_closed_event.py @@ -0,0 +1,59 @@ +import asyncio +import functools + +from aiohttp import ClientSession + + +def create_aiohttp_closed_event(session: ClientSession) -> asyncio.Event: + """Work around aiohttp issue that doesn't properly close transports on exit. + + See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 + + Returns: + An event that will be set once all transports have been properly closed. + """ + + ssl_transports = 0 + all_is_lost = asyncio.Event() + + def connection_lost(exc, orig_lost): + nonlocal ssl_transports + + try: + orig_lost(exc) + finally: + ssl_transports -= 1 + if ssl_transports == 0: + all_is_lost.set() + + def eof_received(orig_eof_received): + try: # pragma: no cover + orig_eof_received() + except AttributeError: # pragma: no cover + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. + pass + + assert session.connector is not None + + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + + ssl_transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + proto.connection_lost = functools.partial( + connection_lost, orig_lost=orig_lost + ) + proto.eof_received = functools.partial( + eof_received, orig_eof_received=orig_eof_received + ) + + if ssl_transports == 0: + all_is_lost.set() + + return all_is_lost diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index 84ba7656..f004d240 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -6,7 +6,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast -from .common.adapters.websockets import AdapterConnection +from .common.adapters.connection import AdapterConnection from .common.base import SubscriptionTransportBase from .exceptions import ( TransportConnectionClosed, @@ -80,7 +80,7 @@ def __init__( self.GRAPHQLWS_SUBPROTOCOL, ] - self.adapter.connect_args.update({"subprotocols": subprotocols}) + self.adapter.subprotocols = subprotocols # Initialize the generic SubscriptionTransportBase parent class super().__init__( diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 8ee44d2c..e4e56fcd 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -7,7 +7,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -148,7 +148,7 @@ async def test_aiohttp_websocket_sending_invalid_data( invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send_str(invalid_data) + await session.transport.adapter.websocket.send_str(invalid_data) await asyncio.sleep(2 * MS) @@ -289,7 +289,7 @@ async def test_aiohttp_websocket_server_closing_directly(event_loop, server): sample_transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async with Client(transport=sample_transport): pass @@ -309,7 +309,7 @@ async def test_aiohttp_websocket_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index b234d296..8f3567a7 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -6,6 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -247,7 +248,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_directly( transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async with Client(transport=transport): pass @@ -267,7 +268,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index d40d15ce..79cf506d 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -390,7 +390,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -839,7 +839,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except ConnectionResetError: + except TransportConnectionClosed: pass await asyncio.sleep(50 * MS) diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index d76d646f..30b35d73 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportQueryError, TransportServerError, ) @@ -60,7 +61,14 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = AIOHTTPWebsocketsTransport(url=url, websocket_close_timeout=10) + transport = AIOHTTPWebsocketsTransport( + url=url, + websocket_close_timeout=10, + headers={"test": "1234"}, + ) + + assert transport.response_headers == {} + assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: @@ -84,7 +92,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -135,7 +143,7 @@ async def test_aiohttp_websocket_using_ssl_connection( assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -166,19 +174,26 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( **extra_args, ) - with pytest.raises(ClientConnectorCertificateError) as exc_info: + if verify_https == "explicitely_enabled": + assert transport.ssl is True + + with pytest.raises(TransportConnectionClosed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, ClientConnectorCertificateError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -380,13 +395,13 @@ async def test_aiohttp_websocket_multiple_connections_in_series( await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -519,8 +534,8 @@ async def test_aiohttp_websocket_connect_failed_with_authentication_in_connectio await session.execute(query1) - assert transport.session is None - assert transport.websocket is None + assert transport.adapter.session is None + assert transport._connected is False @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) @@ -564,7 +579,7 @@ def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -753,6 +768,6 @@ async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_se assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False await connector.close() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 9d2d652b..188e006e 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef @@ -381,7 +381,7 @@ async def test_aiohttp_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): @@ -772,14 +772,12 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s subscription = gql(subscription_str.format(count=count)) async with client as session: - session.transport.websocket._writer._closing = True + session.transport.adapter.websocket._writer._closing = True - with pytest.raises(ConnectionResetError) as e: + with pytest.raises(TransportConnectionClosed): async for _ in session.subscribe(subscription): pass - assert e.value.args[0] == "Cannot write to closing transport" - @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -798,9 +796,7 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) async with client as session: - session.transport.websocket = None - with pytest.raises(TransportClosed) as e: + session.transport.adapter.websocket = None + with pytest.raises(TransportConnectionClosed): async for _ in session.subscribe(subscription): pass - - assert e.value.args[0] == "WebSocket connection is closed" diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 320d1da3..732c0e14 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -1,6 +1,7 @@ import pytest from gql import Client, gql +from gql.transport.exceptions import TransportConnectionClosed from .conftest import get_localhost_ssl_context_client @@ -71,14 +72,10 @@ async def test_phoenix_channel_query(event_loop, server, query_str): assert africa["code"] == "AF" -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_phoenix_channel_query_ssl( - event_loop, ws_ssl_server, query_str, verify_https -): +async def test_phoenix_channel_query_ssl(event_loop, ws_ssl_server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -89,12 +86,9 @@ async def test_phoenix_channel_query_ssl( extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", @@ -138,13 +132,17 @@ async def test_phoenix_channel_query_ssl_self_cert_fail( query = gql(query_str) - with pytest.raises(SSLCertVerificationError) as exc_info: + with pytest.raises(TransportConnectionClosed) as exc_info: async with Client(transport=transport) as session: await session.execute(query) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) query2_str = """ diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 7aa853bf..f7e92840 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportQueryError, TransportServerError, ) @@ -88,11 +89,9 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport._connected is False -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_https): +async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): import websockets from gql.transport.websockets import WebsocketsTransport @@ -103,19 +102,16 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = WebsocketsTransport(url=url, **extra_args) async with Client(transport=transport) as session: assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol + transport.adapter.websocket, websockets.client.WebSocketClientProtocol ) query1 = gql(query1_str) @@ -160,16 +156,20 @@ async def test_websocket_using_ssl_connection_self_cert_fail( if verify_https == "explicitely_enabled": assert transport.ssl is True - with pytest.raises(SSLCertVerificationError) as exc_info: + with pytest.raises(TransportConnectionClosed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here assert transport._connected is False From 750e695315380b9a1e3adb9bdf361d2b8c018c26 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 10 Mar 2025 17:00:12 +0100 Subject: [PATCH 8/9] Fix PyPy tests --- gql/transport/common/base.py | 10 +++++-- tests/conftest.py | 4 +++ ...iohttp_websocket_graphqlws_subscription.py | 28 +++++++++++++++---- tests/test_aiohttp_websocket_subscription.py | 6 +++- tests/test_client.py | 4 +++ tests/test_graphqlws_subscription.py | 28 +++++++++++++++---- tests/test_phoenix_channel_query.py | 6 +++- tests/test_phoenix_channel_subscription.py | 13 +++++++-- tests/test_websocket_subscription.py | 17 +++++++++-- 9 files changed, 95 insertions(+), 21 deletions(-) diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 40d0b4cb..2a4d4d65 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -317,6 +317,8 @@ async def subscribe( if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False + if isinstance(e, GeneratorExit): + raise e finally: log.debug(f"In subscribe finally for query_id {query_id}") @@ -345,6 +347,11 @@ async def execute( first_result = result break + # Apparently, on pypy the GeneratorExit exception is not raised after a break + # --> the clean_close has to time out + # We still need to manually close the async generator + await generator.aclose() + if first_result is None: raise TransportQueryError( "Query completed without any answer received from the server" @@ -445,7 +452,6 @@ async def _clean_close(self, e: Exception) -> None: # Send 'stop' message for all current queries for query_id, listener in self.listeners.items(): - if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False @@ -556,7 +562,7 @@ async def wait_closed(self) -> None: try: await asyncio.wait_for(self._wait_closed.wait(), self.close_timeout) except asyncio.TimeoutError: - log.debug("Timer close_timeout fired in wait_closed") + log.warning("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") diff --git a/tests/conftest.py b/tests/conftest.py index 664fe8c9..f9e11dab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import logging import os import pathlib +import platform import re import ssl import sys @@ -19,6 +20,9 @@ all_transport_dependencies = ["aiohttp", "requests", "httpx", "websockets", "botocore"] +PyPy = platform.python_implementation() == "PyPy" + + def pytest_addoption(parser): parser.addoption( "--run-online", diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 79cf506d..e97da29a 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -10,7 +10,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportConnectionClosed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the aiohttp AND websockets marker pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] @@ -260,7 +260,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,6 +275,9 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @@ -847,23 +851,33 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( # Then with the same session handle, we make a subscription or an execute # which will detect that the transport is closed so that the client could # try to reconnect + generator = None try: if execute_instead_of_subscribe: print("\nEXECUTION_2\n") await session.execute(subscription) else: print("\nSUBSCRIPTION_2\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: pass - except TransportClosed: + except (TransportClosed, TransportConnectionClosed): + if generator: + await generator.aclose() pass - await asyncio.sleep(50 * MS) + timeout = 50 + + if PyPy: + timeout = 500 + + await asyncio.sleep(timeout * MS) # And finally with the same session handle, we make a subscription # which works correctly print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -871,6 +885,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( assert number == count count -= 1 + await generator.aclose() + assert count == -1 await client.close_async() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 188e006e..61270fe1 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -250,7 +250,8 @@ async def test_aiohttp_websocket_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -264,6 +265,9 @@ async def test_aiohttp_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) diff --git a/tests/test_client.py b/tests/test_client.py index 1e794558..e5edec8b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -280,3 +280,7 @@ async def test_async_transport_close_on_schema_retrieval_failure(): pass assert client.transport.session is None + + import asyncio + + await asyncio.sleep(1) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 1b8f7ccb..8284fea8 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -10,7 +10,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportConnectionClosed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -260,7 +260,8 @@ async def test_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,6 +275,9 @@ async def test_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @@ -843,23 +847,33 @@ async def test_graphqlws_subscription_reconnecting_session( # Then with the same session handle, we make a subscription or an execute # which will detect that the transport is closed so that the client could # try to reconnect + generator = None try: if execute_instead_of_subscribe: print("\nEXECUTION_2\n") await session.execute(subscription) else: print("\nSUBSCRIPTION_2\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: pass - except TransportClosed: + except (TransportClosed, TransportConnectionClosed): + if generator: + await generator.aclose() pass - await asyncio.sleep(50 * MS) + timeout = 50 + + if PyPy: + timeout = 500 + + await asyncio.sleep(timeout * MS) # And finally with the same session handle, we make a subscription # which works correctly print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -867,6 +881,8 @@ async def test_graphqlws_subscription_reconnecting_session( assert number == count count -= 1 + await generator.aclose() + assert count == -1 await client.close_async() diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 732c0e14..16d4e4f4 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -216,8 +216,12 @@ async def test_phoenix_channel_subscription(event_loop, server, query_str): first_result = None query = gql(query_str) async with Client(transport=transport) as session: - async for result in session.subscribe(query): + generator = session.subscribe(query) + async for result in generator: first_result = result break + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + print("Client received:", first_result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 3be4b07d..35ca665b 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -201,7 +201,9 @@ async def test_phoenix_channel_subscription( subscription = gql(subscription_str.format(count=count)) async with Client(transport=sample_transport) as session: - async for result in session.subscribe(subscription): + + generator = session.subscribe(subscription) + async for result in generator: number = result["countdown"]["number"] print(f"Number received: {number}") @@ -212,6 +214,9 @@ async def test_phoenix_channel_subscription( count -= 1 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + assert count == end_count @@ -378,7 +383,8 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): subscription = gql(heartbeat_subscription_str) async with Client(transport=sample_transport) as session: i = 0 - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: heartbeat_count = result["heartbeat"]["heartbeat_count"] print(f"Heartbeat count received: {heartbeat_count}") @@ -387,3 +393,6 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): break i += 1 + + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 3efe63a6..927db4e9 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -11,7 +11,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportConnectionClosed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -181,7 +181,8 @@ async def test_websocket_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -195,6 +196,9 @@ async def test_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -413,7 +417,14 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + + keep_alive_timeout = 20 * MS + if PyPy: + keep_alive_timeout = 200 * MS + + sample_transport = WebsocketsTransport( + url=url, keep_alive_timeout=keep_alive_timeout + ) client = Client(transport=sample_transport) From 7fb869a6abd8fb035e00f3706a9e104f5eb655eb Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 11 Mar 2025 11:31:24 +0100 Subject: [PATCH 9/9] Renaming TransportConnectionClosed to TransportConnectionFailed --- gql/transport/common/adapters/aiohttp.py | 16 ++++++++-------- gql/transport/common/adapters/connection.py | 4 ++-- gql/transport/common/adapters/websockets.py | 16 ++++++++-------- gql/transport/common/base.py | 12 ++++++------ gql/transport/exceptions.py | 2 +- gql/transport/phoenix_channel_websockets.py | 4 ++-- gql/transport/websockets_protocol.py | 4 ++-- tests/test_aiohttp_websocket_exceptions.py | 6 +++--- ...est_aiohttp_websocket_graphqlws_exceptions.py | 6 +++--- ...t_aiohttp_websocket_graphqlws_subscription.py | 8 ++++---- tests/test_aiohttp_websocket_query.py | 4 ++-- tests/test_aiohttp_websocket_subscription.py | 8 ++++---- tests/test_graphqlws_exceptions.py | 6 +++--- tests/test_graphqlws_subscription.py | 8 ++++---- tests/test_phoenix_channel_query.py | 4 ++-- tests/test_websocket_exceptions.py | 6 +++--- tests/test_websocket_query.py | 4 ++-- tests/test_websocket_subscription.py | 4 ++-- tests/test_websockets_adapter.py | 6 +++--- 19 files changed, 64 insertions(+), 64 deletions(-) diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py index d9af7c50..f2dff699 100644 --- a/gql/transport/common/adapters/aiohttp.py +++ b/gql/transport/common/adapters/aiohttp.py @@ -8,7 +8,7 @@ from aiohttp.typedefs import LooseHeaders, StrOrURL from multidict import CIMultiDictProxy -from ...exceptions import TransportConnectionClosed, TransportProtocolError +from ...exceptions import TransportConnectionFailed, TransportProtocolError from ..aiohttp_closed_event import create_aiohttp_closed_event from .connection import AdapterConnection @@ -160,7 +160,7 @@ async def connect(self) -> None: **connect_args, ) except Exception as e: - raise TransportConnectionClosed("Connect failed") from e + raise TransportConnectionFailed("Connect failed") from e self._response_headers = self.websocket._response.headers @@ -171,15 +171,15 @@ async def send(self, message: str) -> None: message: String message to send Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") try: await self.websocket.send_str(message) except ConnectionResetError as e: - raise TransportConnectionClosed("Connection was closed") from e + raise TransportConnectionFailed("Connection was closed") from e async def receive(self) -> str: """Receive message from the WebSocket server. @@ -188,12 +188,12 @@ async def receive(self) -> str: String message received Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed TransportProtocolError: If protocol error or binary data received """ # It is possible that the websocket has been already closed in another task if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") while True: ws_message = await self.websocket.receive() @@ -208,7 +208,7 @@ async def receive(self) -> str: WSMsgType.CLOSING, WSMsgType.ERROR, ): - raise TransportConnectionClosed("Connection was closed") + raise TransportConnectionFailed("Connection was closed") elif ws_message.type is WSMsgType.BINARY: raise TransportProtocolError("Binary data received in the websocket") diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py index f3d77421..ac178bc6 100644 --- a/gql/transport/common/adapters/connection.py +++ b/gql/transport/common/adapters/connection.py @@ -35,7 +35,7 @@ async def send(self, message: str) -> None: message: String message to send Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed """ pass # pragma: no cover @@ -47,7 +47,7 @@ async def receive(self) -> str: String message received Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed TransportProtocolError: If protocol error or binary data received """ pass # pragma: no cover diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index 383d4def..c2524fb4 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -6,7 +6,7 @@ from websockets.client import WebSocketClientProtocol from websockets.datastructures import Headers, HeadersLike -from ...exceptions import TransportConnectionClosed, TransportProtocolError +from ...exceptions import TransportConnectionFailed, TransportProtocolError from .connection import AdapterConnection log = logging.getLogger("gql.transport.common.adapters.websockets") @@ -70,7 +70,7 @@ async def connect(self) -> None: try: self.websocket = await websockets.client.connect(self.url, **connect_args) except Exception as e: - raise TransportConnectionClosed("Connect failed") from e + raise TransportConnectionFailed("Connect failed") from e self._response_headers = self.websocket.response_headers @@ -81,15 +81,15 @@ async def send(self, message: str) -> None: message: String message to send Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") try: await self.websocket.send(message) except Exception as e: - raise TransportConnectionClosed("Connection was closed") from e + raise TransportConnectionFailed("Connection was closed") from e async def receive(self) -> str: """Receive message from the WebSocket server. @@ -98,18 +98,18 @@ async def receive(self) -> str: String message received Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed TransportProtocolError: If protocol error or binary data received """ # It is possible that the websocket has been already closed in another task if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") # Wait for the next websocket frame. Can raise ConnectionClosed try: data = await self.websocket.recv() except Exception as e: - raise TransportConnectionClosed("Connection was closed") from e + raise TransportConnectionFailed("Connection was closed") from e # websocket.recv() can return either str or bytes # In our case, we should receive only str here diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 2a4d4d65..770a8b34 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -11,7 +11,7 @@ from ..exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -134,7 +134,7 @@ async def _send(self, message: str) -> None: try: await self.adapter.send(message) log.info(">>> %s", message) - except TransportConnectionClosed as e: + except TransportConnectionFailed as e: await self._fail(e, clean_close=False) raise e @@ -146,7 +146,7 @@ async def _receive(self) -> str: raise TransportClosed("Transport is already closed") # Wait for the next frame. - # Can raise TransportConnectionClosed or TransportProtocolError + # Can raise TransportConnectionFailed or TransportProtocolError answer: str = await self.adapter.receive() log.info("<<< %s", answer) @@ -211,7 +211,7 @@ async def _receive_data_loop(self) -> None: # Wait the next answer from the server try: answer = await self._receive() - except (TransportConnectionClosed, TransportProtocolError) as e: + except (TransportConnectionFailed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break except TransportClosed: @@ -296,7 +296,7 @@ async def subscribe( while True: # Wait for the answer from the queue of this query_id - # This can raise TransportError or TransportConnectionClosed + # This can raise TransportError or TransportConnectionFailed answer_type, execution_result = await listener.get() # If the received answer contains data, @@ -402,7 +402,7 @@ async def connect(self) -> None: # if no ACKs are received within the ack_timeout try: await self._initialize() - except TransportConnectionClosed as e: + except TransportConnectionFailed as e: raise e except ( TransportProtocolError, diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 27cefe2f..3e63f0bc 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -61,7 +61,7 @@ class TransportClosed(TransportError): """ -class TransportConnectionClosed(TransportError): +class TransportConnectionFailed(TransportError): """Transport adapter connection closed. This exception is by the connection adapter code when a connection closed. diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 0c1bd62b..3885fcac 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -8,7 +8,7 @@ from .common.adapters.websockets import WebSocketsAdapter from .common.base import SubscriptionTransportBase from .exceptions import ( - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -127,7 +127,7 @@ async def heartbeat_coro(): } ) ) - except TransportConnectionClosed: # pragma: no cover + except TransportConnectionFailed: # pragma: no cover return self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index f004d240..3348c576 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -9,7 +9,7 @@ from .common.adapters.connection import AdapterConnection from .common.base import SubscriptionTransportBase from .exceptions import ( - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -508,7 +508,7 @@ async def _close_hook(self): if self.send_ping_task is not None: log.debug("_close_hook: cancelling send_ping_task") self.send_ping_task.cancel() - with suppress(asyncio.CancelledError, TransportConnectionClosed): + with suppress(asyncio.CancelledError, TransportConnectionFailed): log.debug("_close_hook: awaiting send_ping_task") await self.send_ping_task self.send_ping_task = None diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index e4e56fcd..81c79ba7 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -7,7 +7,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -289,7 +289,7 @@ async def test_aiohttp_websocket_server_closing_directly(event_loop, server): sample_transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -309,7 +309,7 @@ async def test_aiohttp_websocket_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index 8f3567a7..f87682d2 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -248,7 +248,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_directly( transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=transport): pass @@ -268,7 +268,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index e97da29a..f380948c 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -394,7 +394,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -843,7 +843,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except TransportConnectionClosed: + except TransportConnectionFailed: pass await asyncio.sleep(50 * MS) @@ -861,7 +861,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( generator = session.subscribe(subscription) async for result in generator: pass - except (TransportClosed, TransportConnectionClosed): + except (TransportClosed, TransportConnectionFailed): if generator: await generator.aclose() pass diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 30b35d73..8786d58d 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -9,7 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -177,7 +177,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( if verify_https == "explicitely_enabled": assert transport.ssl is True - with pytest.raises(TransportConnectionClosed) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 61270fe1..4ea11a7b 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef @@ -385,7 +385,7 @@ async def test_aiohttp_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -778,7 +778,7 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s async with client as session: session.transport.adapter.websocket._writer._closing = True - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass @@ -801,6 +801,6 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) async with client as session: session.transport.adapter.websocket = None - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index cce31d59..3b6bd901 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -241,7 +241,7 @@ async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -261,7 +261,7 @@ async def test_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 8284fea8..d4bed34f 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -394,7 +394,7 @@ async def test_graphqlws_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -839,7 +839,7 @@ async def test_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except TransportConnectionClosed: + except TransportConnectionFailed: pass await asyncio.sleep(50 * MS) @@ -857,7 +857,7 @@ async def test_graphqlws_subscription_reconnecting_session( generator = session.subscribe(subscription) async for result in generator: pass - except (TransportClosed, TransportConnectionClosed): + except (TransportClosed, TransportConnectionFailed): if generator: await generator.aclose() pass diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 16d4e4f4..56d28875 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -1,7 +1,7 @@ import pytest from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed +from gql.transport.exceptions import TransportConnectionFailed from .conftest import get_localhost_ssl_context_client @@ -132,7 +132,7 @@ async def test_phoenix_channel_query_ssl_self_cert_fail( query = gql(query_str) - with pytest.raises(TransportConnectionClosed) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: await session.execute(query) diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index f9f1f8db..68b2fe52 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -9,7 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -280,7 +280,7 @@ async def test_websocket_server_closing_directly(event_loop, server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -298,7 +298,7 @@ async def test_websocket_server_closing_after_ack(event_loop, client_and_server) query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index f7e92840..b1e3c07a 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -9,7 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -156,7 +156,7 @@ async def test_websocket_using_ssl_connection_self_cert_fail( if verify_https == "explicitely_enabled": assert transport.ssl is True - with pytest.raises(TransportConnectionClosed) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 927db4e9..6f291218 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -315,7 +315,7 @@ async def test_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py index f266ce29..85fbf00a 100644 --- a/tests/test_websockets_adapter.py +++ b/tests/test_websockets_adapter.py @@ -4,7 +4,7 @@ from graphql import print_ast from gql import gql -from gql.transport.exceptions import TransportConnectionClosed +from gql.transport.exceptions import TransportConnectionFailed # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -91,8 +91,8 @@ async def test_websockets_adapter_edge_cases(event_loop, server): # Second close call is ignored await adapter.close() - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await adapter.send("Blah") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await adapter.receive()