From 1480d3b81fb8fea877e5e205512a881b87458952 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 26 May 2025 17:54:44 +0200 Subject: [PATCH 1/3] Implementation of automatic batching for async --- gql/client.py | 177 +++++++++++++++++++++++++++++---- gql/transport/aiohttp.py | 48 +++++---- gql/transport/httpx.py | 29 ++++-- gql/transport/requests.py | 70 ++++++------- tests/test_aiohttp_batch.py | 190 ++++++++++++++++++++++++++++++++++++ 5 files changed, 437 insertions(+), 77 deletions(-) diff --git a/gql/client.py b/gql/client.py index 4e269a2a..a0e07056 100644 --- a/gql/client.py +++ b/gql/client.py @@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs): if reconnecting: self.session = ReconnectingAsyncClientSession(client=self, **kwargs) - await self.session.start_connecting_task() else: - try: - await self.transport.connect() - except Exception as e: - await self.transport.close() - raise e self.session = AsyncClientSession(client=self) + await self.session.connect() + # Get schema from transport if needed try: if self.fetch_schema_from_transport and not self.schema: @@ -846,7 +842,7 @@ async def connect_async(self, reconnecting=False, **kwargs): # we don't know what type of exception is thrown here because it # depends on the underlying transport; we just make sure that the # transport is closed and re-raise the exception - await self.transport.close() + await self.session.close() raise return self.session @@ -854,10 +850,7 @@ async def connect_async(self, reconnecting=False, **kwargs): async def close_async(self): """Close the async transport and stop the optional reconnecting task.""" - if isinstance(self.session, ReconnectingAsyncClientSession): - await self.session.stop_connecting_task() - - await self.transport.close() + await self.session.close() async def __aenter__(self): return await self.connect_async() @@ -1564,12 +1557,17 @@ async def _execute( ): request = request.serialize_variable_values(self.client.schema) - # Execute the query with the transport with a timeout - with fail_after(self.client.execute_timeout): - result = await self.transport.execute( - request, - **kwargs, - ) + # Check if batching is enabled + if self.client.batching_enabled: + future_result = await self._execute_future(request) + result = await future_result + else: + # Execute the query with the transport with a timeout + with fail_after(self.client.execute_timeout): + result = await self.transport.execute( + request, + **kwargs, + ) # Unserialize the result if requested if self.client.schema: @@ -1828,6 +1826,134 @@ async def execute_batch( return cast(List[Dict[str, Any]], [result.data for result in results]) + async def _batch_loop(self) -> None: + """Main loop of the task used to wait for requests + to execute them in a batch""" + + stop_loop = False + + while not stop_loop: + # First wait for a first request in from the batch queue + requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = [] + + # Wait for the first request + request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = ( + await self.batch_queue.get() + ) + + if request_and_future is None: + # None is our sentinel value to stop the loop + break + + requests_and_futures.append(request_and_future) + + # Then wait the requested batch interval except if we already + # have the maximum number of requests in the queue + if self.batch_queue.qsize() < self.client.batch_max - 1: + # Wait for the batch interval + await asyncio.sleep(self.client.batch_interval) + + # Then get the requests which had been made during that wait interval + for _ in range(self.client.batch_max - 1): + try: + # Use get_nowait since we don't want to wait here + request_and_future = self.batch_queue.get_nowait() + + if request_and_future is None: + # Sentinel value - stop after processing current batch + stop_loop = True + break + + requests_and_futures.append(request_and_future) + + except asyncio.QueueEmpty: + # No more requests in queue, that's fine + break + + # Extract requests and futures + requests = [request for request, _ in requests_and_futures] + futures = [future for _, future in requests_and_futures] + + # Execute the batch + try: + results: List[ExecutionResult] = await self._execute_batch( + requests, + serialize_variables=False, # already done + parse_result=False, # will be done later + validate_document=False, # already validated + ) + + # Set the result for each future + for result, future in zip(results, futures): + if not future.cancelled(): + future.set_result(result) + + except Exception as exc: + # If batch execution fails, propagate the error to all futures + for future in futures: + if not future.cancelled(): + future.set_exception(exc) + + # Signal that the task has stopped + self._batch_task_stopped_event.set() + + async def _execute_future( + self, + request: GraphQLRequest, + ) -> asyncio.Future: + """If batching is enabled, this method will put a request in the batching queue + instead of executing it directly so that the requests could be put in a batch. + """ + + assert hasattr(self, "batch_queue"), "Batching is not enabled" + assert not self._batch_task_stop_requested, "Batching task has been stopped" + + future: asyncio.Future = asyncio.Future() + await self.batch_queue.put((request, future)) + + return future + + async def _batch_init(self): + """Initialize the batch task loop if batching is enabled.""" + if self.client.batching_enabled: + self.batch_queue: asyncio.Queue = asyncio.Queue() + self._batch_task_stop_requested = False + self._batch_task_stopped_event = asyncio.Event() + self._batch_task = asyncio.create_task(self._batch_loop()) + + async def _batch_cleanup(self): + """Cleanup the batching task if batching is enabled.""" + if hasattr(self, "_batch_task_stopped_event"): + # Send a None in the queue to indicate that the batching task must stop + # after having processed the remaining requests in the queue + self._batch_task_stop_requested = True + await self.batch_queue.put(None) + + # Wait for the task to process remaining requests and stop + await self._batch_task_stopped_event.wait() + + async def connect(self): + """Connect the transport and initialize the batch task loop if batching + is enabled.""" + + await self._batch_init() + + try: + await self.transport.connect() + except Exception as e: + await self.transport.close() + raise e + + async def close(self): + """Close the transport and cleanup the batching task if batching is enabled. + + Will wait until all the remaining requests in the batch processing queue + have been executed. + """ + await self._batch_cleanup() + + await self.transport.close() + async def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. @@ -1954,6 +2080,23 @@ async def stop_connecting_task(self): self._connect_task.cancel() self._connect_task = None + async def connect(self): + """Start the connect task and initialize the batch task loop if batching + is enabled.""" + + await self._batch_init() + + await self.start_connecting_task() + + async def close(self): + """Stop the connect task and cleanup the batching task + if batching is enabled.""" + await self._batch_cleanup() + + await self.stop_connecting_task() + + await self.transport.close() + async def _execute_once( self, request: GraphQLRequest, diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 2c0d8fa7..61d01fb4 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -274,22 +274,35 @@ def _prepare_file_uploads( return post_args - async def raise_response_error( - self, + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( resp: aiohttp.ClientResponse, - reason: str, ) -> None: - # We raise a TransportServerError if status code is 400 or higher - # We raise a TransportProtocolError in the other cases - + # If the status is >400, + # then we need to raise a TransportServerError try: # Raise ClientResponseError if response status is 400 or higher resp.raise_for_status() except ClientResponseError as e: raise TransportServerError(str(e), e.status) from e + @classmethod + async def _raise_response_error( + cls, + resp: aiohttp.ClientResponse, + reason: str, + ) -> None: + # We raise a TransportServerError if status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + cls._raise_transport_server_error_if_status_more_than_400(resp) + result_text = await resp.text() - self._raise_invalid_result(result_text, reason) + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " + f"{reason}: " + f"{result_text}" + ) async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: @@ -304,10 +317,10 @@ async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: log.debug("<<< %s", result_text) except Exception: - await self.raise_response_error(response, "Not a JSON answer") + await self._raise_response_error(response, "Not a JSON answer") if result is None: - await self.raise_response_error(response, "Not a JSON answer") + await self._raise_response_error(response, "Not a JSON answer") return result @@ -318,7 +331,7 @@ async def _prepare_result( result = await self._get_json_result(response) if "errors" not in result and "data" not in result: - await self.raise_response_error( + await self._raise_response_error( response, 'No "data" or "errors" keys in answer' ) @@ -336,14 +349,13 @@ async def _prepare_batch_result( answers = await self._get_json_result(response) - return get_batch_execution_result_list(reqs, answers) - - def _raise_invalid_result(self, result_text: str, reason: str) -> None: - raise TransportProtocolError( - f"Server did not return a valid GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + try: + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise async def execute( self, diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 76324cd7..afb1360c 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -195,18 +195,33 @@ def _prepare_batch_result( answers = self._get_json_result(response) - return get_batch_execution_result_list(reqs, answers) - - def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn: - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases - try: - # Raise a HTTPError if response status is 400 or higher + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise + + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( + response: httpx.Response, + ) -> None: + # If the status is >400, + # then we need to raise a TransportServerError + try: + # Raise a HTTPStatusError if response status is 400 or higher response.raise_for_status() except httpx.HTTPStatusError as e: raise TransportServerError(str(e), e.response.status_code) from e + @classmethod + def _raise_response_error(cls, response: httpx.Response, reason: str) -> NoReturn: + # We raise a TransportServerError if the status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + cls._raise_transport_server_error_if_status_more_than_400(response) + raise TransportProtocolError( f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}" ) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 7be288d2..16d07025 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -258,24 +258,6 @@ def execute( # type: ignore self.response_headers = response.headers - def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases - - try: - # Raise a HTTPError if response status is 400 or higher - resp.raise_for_status() - except requests.HTTPError as e: - status_code = e.response.status_code if e.response is not None else None - raise TransportServerError(str(e), status_code) from e - - result_text = resp.text - raise TransportProtocolError( - f"Server did not return a GraphQL result: " - f"{reason}: " - f"{result_text}" - ) - try: if self.json_deserialize == json.loads: result = response.json() @@ -286,10 +268,10 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: log.debug("<<< %s", response.text) except Exception: - raise_response_error(response, "Not a JSON answer") + self._raise_response_error(response, "Not a JSON answer") if "errors" not in result and "data" not in result: - raise_response_error(response, 'No "data" or "errors" keys in answer') + self._raise_response_error(response, 'No "data" or "errors" keys in answer') return ExecutionResult( errors=result.get("errors"), @@ -297,6 +279,31 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: extensions=result.get("extensions"), ) + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( + response: requests.Response, + ) -> None: + # If the status is >400, + # then we need to raise a TransportServerError + try: + # Raise a HTTPError if response status is 400 or higher + response.raise_for_status() + except requests.HTTPError as e: + status_code = e.response.status_code if e.response is not None else None + raise TransportServerError(str(e), status_code) from e + + @classmethod + def _raise_response_error(cls, resp: requests.Response, reason: str) -> NoReturn: + # We raise a TransportServerError if the status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + cls._raise_transport_server_error_if_status_more_than_400(resp) + + result_text = resp.text + raise TransportProtocolError( + f"Server did not return a GraphQL result: " f"{reason}: " f"{result_text}" + ) + def execute_batch( self, reqs: List[GraphQLRequest], @@ -330,30 +337,23 @@ def execute_batch( answers = self._extract_response(response) - return get_batch_execution_result_list(reqs, answers) - - def _raise_invalid_result(self, result_text: str, reason: str) -> None: - raise TransportProtocolError( - f"Server did not return a valid GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + try: + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise def _extract_response(self, response: requests.Response) -> Any: try: - response.raise_for_status() result = response.json() if log.isEnabledFor(logging.DEBUG): log.debug("<<< %s", response.text) - except requests.HTTPError as e: - raise TransportServerError( - str(e), e.response.status_code if e.response is not None else None - ) from e - except Exception: - self._raise_invalid_result(str(response.text), "Not a JSON answer") + self._raise_response_error(response, "Not a JSON answer") return result diff --git a/tests/test_aiohttp_batch.py b/tests/test_aiohttp_batch.py index f04f05e4..e3407a4d 100644 --- a/tests/test_aiohttp_batch.py +++ b/tests/test_aiohttp_batch.py @@ -1,3 +1,4 @@ +import asyncio from typing import Mapping import pytest @@ -7,6 +8,7 @@ TransportClosed, TransportProtocolError, TransportQueryError, + TransportServerError, ) # Marking all tests in this file with the aiohttp marker @@ -29,6 +31,21 @@ '{"code":"SA","name":"South America"}]}}]' ) +query1_server_answer_twice_list = ( + "[" + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}' + "]" +) + @pytest.mark.asyncio async def test_aiohttp_batch_query(aiohttp_server): @@ -72,6 +89,179 @@ async def handler(request): assert transport.response_headers["dummy"] == "test1234" +@pytest.mark.asyncio +async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, # 10ms batch interval + ) as session: + + query = gql(query1_str) + + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.asyncio +async def test_aiohttp_batch_auto_two_requests(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + async def test_coroutine(): + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Create two concurrent tasks that will be batched together + tasks = [] + for _ in range(2): + task = asyncio.create_task(test_coroutine()) + tasks.append(task) + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_auto_two_requests_close_session_directly(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.1, + ) as session: + + async def test_coroutine(): + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Create two concurrent tasks that will be batched together + tasks = [] + for _ in range(2): + task = asyncio.create_task(test_coroutine()) + tasks.append(task) + + await asyncio.sleep(0.01) + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_error_code_401(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, # 10ms batch interval + ) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "401, message='Unauthorized'" in str(exc_info.value) + + @pytest.mark.asyncio async def test_aiohttp_batch_query_without_session(aiohttp_server, run_sync_test): from aiohttp import web From 93274eadc8f6fe36238cc95628890ef64d5fb782 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 27 May 2025 18:20:32 +0200 Subject: [PATCH 2/3] Addind Batching docs --- README.md | 1 + docs/advanced/batching_requests.rst | 96 +++++++++++++++++++++++++++++ docs/advanced/index.rst | 1 + 3 files changed, 98 insertions(+) create mode 100644 docs/advanced/batching_requests.rst diff --git a/README.md b/README.md index cbc53af6..e79a63d2 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ The complete documentation for GQL can be found at * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) * Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html) +* Supports [Batching requests](https://gql.readthedocs.io/en/latest/advanced/batching_requests.html) * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries or download schemas from the command line * [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically diff --git a/docs/advanced/batching_requests.rst b/docs/advanced/batching_requests.rst new file mode 100644 index 00000000..a71d4ffc --- /dev/null +++ b/docs/advanced/batching_requests.rst @@ -0,0 +1,96 @@ +.. _batching_requests: + +Batching requests +================= + +If you need to send multiple GraphQL queries to a backend, +and if the backend supports batch requests, +then you might want to send those requests in a batch instead of +making multiple execution requests. + +.. warning:: + - Some backends do not support batch requests + - File uploads and subscriptions are not supported with batch requests + +Batching requests manually +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To execute a batch of requests manually: + +- First Make a list of :class:`GraphQLRequest ` objects, containing: + * your GraphQL query + * Optional variable_values + * Optional operation_name + +.. code-block:: python + + request1 = GraphQLRequest(""" + query getContinents { + continents { + code + name + } + } + """ + ) + + request2 = GraphQLRequest(""" + query getContinentName ($code: ID!) { + continent (code: $code) { + name + } + } + """, + variable_values={ + "code": "AF", + }, + ) + + requests = [request1, request2] + +- Then use one of the `execute_batch` methods, either on Client, + or in a sync or async session + +**Sync**: + +.. code-block:: python + + transport = RequestsHTTPTransport(url=url) + # Or transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + results = session.execute_batch(requests) + + result1 = results[0] + result2 = results[1] + +**Async**: + +.. code-block:: python + + transport = AIOHTTPTransport(url=url) + # Or transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + results = await session.execute_batch(requests) + + result1 = results[0] + result2 = results[1] + +.. note:: + If any request in the batch returns an error, then a TransportQueryError will be raised + with the first error found. + +Automatic Batching of requests +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If your code execute multiple requests independently in a short time +(either from different threads in sync code, or from different asyncio tasks in async code), +then you can use gql automatic batching of request functionality. + +You define a :code:`batching_interval` in your :class:`Client ` +and each time a new execution request is received through an `execute` method, +we will wait that interval (in seconds) for other requests to arrive +before sending all the requests received in that interval in a single batch. diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst index baae9276..ef14defd 100644 --- a/docs/advanced/index.rst +++ b/docs/advanced/index.rst @@ -6,6 +6,7 @@ Advanced async_advanced_usage async_permanent_session + batching_requests logging error_handling local_schema From d02bed3f932af2b396f6f991885ac13204da573f Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 27 May 2025 18:46:54 +0200 Subject: [PATCH 3/3] Modify GraphQLRequest to allow str input --- gql/graphql_request.py | 39 ++++++++++++++++++++++++----------- tests/test_graphql_request.py | 12 ++++++++++- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/gql/graphql_request.py b/gql/graphql_request.py index 7289a8f9..29a34717 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,26 +1,38 @@ -from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from graphql import DocumentNode, GraphQLSchema, print_ast +from .gql import gql from .utilities import serialize_variable_values -@dataclass(frozen=True) class GraphQLRequest: """GraphQL Request to be executed.""" - document: DocumentNode - """GraphQL query as AST Node object.""" + def __init__( + self, + document: Union[DocumentNode, str], + *, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ): + """ + Initialize a GraphQL request. - variable_values: Optional[Dict[str, Any]] = None - """Dictionary of input parameters (Default: None).""" + Args: + document: GraphQL query as AST Node object or as a string. + If string, it will be converted to DocumentNode using gql(). + variable_values: Dictionary of input parameters (Default: None). + operation_name: Name of the operation that shall be executed. + Only required in multi-operation documents (Default: None). + """ + if isinstance(document, str): + self.document = gql(document) + else: + self.document = document - operation_name: Optional[str] = None - """ - Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). - """ + self.variable_values = variable_values + self.operation_name = operation_name def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": assert self.variable_values @@ -48,3 +60,6 @@ def payload(self) -> Dict[str, Any]: payload["variables"] = self.variable_values return payload + + def __str__(self): + return str(self.payload) diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py index 4c9e7d76..346dc00e 100644 --- a/tests/test_graphql_request.py +++ b/tests/test_graphql_request.py @@ -20,7 +20,7 @@ from gql import GraphQLRequest, gql -from .conftest import MS +from .conftest import MS, strip_braces_spaces # Marking all tests in this file with the aiohttp marker pytestmark = pytest.mark.aiohttp @@ -200,3 +200,13 @@ def test_serialize_variables_using_money_example(): req = req.serialize_variable_values(schema) assert req.variable_values == {"money": {"amount": 10, "currency": "DM"}} + + +def test_graphql_request_using_string_instead_of_document(): + request = GraphQLRequest("{balance}") + + expected_payload = "{'query': '{\\n balance\\n}'}" + + print(request) + + assert str(request) == strip_braces_spaces(expected_payload)