diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py
index 61d01fb4..40e212cf 100644
--- a/gql/transport/aiohttp.py
+++ b/gql/transport/aiohttp.py
@@ -127,7 +127,7 @@ async def connect(self) -> None:
 
         # Adding custom parameters passed from init
         if self.client_session_args:
-            client_session_args.update(self.client_session_args)  # type: ignore
+            client_session_args.update(self.client_session_args)
 
         log.debug("Connecting transport")
 
@@ -164,36 +164,22 @@ async def close(self) -> None:
 
         self.session = None
 
-    def _prepare_batch_request(
-        self,
-        reqs: List[GraphQLRequest],
-        extra_args: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Any]:
-
-        payload = [req.payload for req in reqs]
-
-        post_args = {"json": payload}
-
-        # Log the payload
-        if log.isEnabledFor(logging.DEBUG):
-            log.debug(">>> %s", self.json_serialize(post_args["json"]))
-
-        # Pass post_args to aiohttp post method
-        if extra_args:
-            post_args.update(extra_args)
-
-        return post_args
-
     def _prepare_request(
         self,
-        request: GraphQLRequest,
+        request: Union[GraphQLRequest, List[GraphQLRequest]],
         extra_args: Optional[Dict[str, Any]] = None,
         upload_files: bool = False,
     ) -> Dict[str, Any]:
 
-        payload = request.payload
+        payload: Dict | List
+        if isinstance(request, GraphQLRequest):
+            payload = request.payload
+        else:
+            payload = [req.payload for req in request]
 
         if upload_files:
+            assert isinstance(payload, Dict)
+            assert isinstance(request, GraphQLRequest)
             post_args = self._prepare_file_uploads(request, payload)
         else:
             post_args = {"json": payload}
@@ -379,15 +365,15 @@ async def execute(
 
         :returns: an ExecutionResult object.
         """
+        if self.session is None:
+            raise TransportClosed("Transport is not connected")
+
         post_args = self._prepare_request(
             request,
             extra_args,
             upload_files,
         )
 
-        if self.session is None:
-            raise TransportClosed("Transport is not connected")
-
         try:
             async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:
                 return await self._prepare_result(resp)
@@ -413,14 +399,14 @@ async def execute_batch(
             if an error occurred.
         """
 
-        post_args = self._prepare_batch_request(
+        if self.session is None:
+            raise TransportClosed("Transport is not connected")
+
+        post_args = self._prepare_request(
             reqs,
             extra_args,
         )
 
-        if self.session is None:
-            raise TransportClosed("Transport is not connected")
-
         async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:
             return await self._prepare_batch_result(reqs, resp)
 
diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py
index afb1360c..7fe2a7db 100644
--- a/gql/transport/httpx.py
+++ b/gql/transport/httpx.py
@@ -59,15 +59,22 @@ def __init__(
 
     def _prepare_request(
         self,
-        req: GraphQLRequest,
+        request: Union[GraphQLRequest, List[GraphQLRequest]],
+        *,
         extra_args: Optional[Dict[str, Any]] = None,
         upload_files: bool = False,
     ) -> Dict[str, Any]:
 
-        payload = req.payload
+        payload: Dict | List
+        if isinstance(request, GraphQLRequest):
+            payload = request.payload
+        else:
+            payload = [req.payload for req in request]
 
         if upload_files:
-            post_args = self._prepare_file_uploads(req, payload)
+            assert isinstance(payload, Dict)
+            assert isinstance(request, GraphQLRequest)
+            post_args = self._prepare_file_uploads(request, payload)
         else:
             post_args = {"json": payload}
 
@@ -81,26 +88,6 @@ def _prepare_request(
 
         return post_args
 
-    def _prepare_batch_request(
-        self,
-        reqs: List[GraphQLRequest],
-        extra_args: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Any]:
-
-        payload = [req.payload for req in reqs]
-
-        post_args = {"json": payload}
-
-        # Log the payload
-        if log.isEnabledFor(logging.DEBUG):
-            log.debug(">>> %s", self.json_serialize(payload))
-
-        # Pass post_args to aiohttp post method
-        if extra_args:
-            post_args.update(extra_args)
-
-        return post_args
-
     def _prepare_file_uploads(
         self,
         request: GraphQLRequest,
@@ -244,7 +231,7 @@ def connect(self):
 
         self.client = httpx.Client(**self.kwargs)
 
-    def execute(  # type: ignore
+    def execute(
         self,
         request: GraphQLRequest,
         *,
@@ -269,8 +256,8 @@ def execute(  # type: ignore
 
         post_args = self._prepare_request(
             request,
-            extra_args,
-            upload_files,
+            extra_args=extra_args,
+            upload_files=upload_files,
         )
 
         try:
@@ -292,7 +279,7 @@ def execute_batch(
         :code:`execute_batch` on a client or a session.
 
         :param reqs: GraphQL requests as a list of GraphQLRequest objects.
-        :param extra_args: additional arguments to send to the aiohttp post method
+        :param extra_args: additional arguments to send to the httpx post method
         :return: A list of results of execution.
             For every result `data` is the result of executing the query,
            `errors` is null if no errors occurred, and is a non-empty array
@@ -302,9 +289,9 @@ def execute_batch(
         if not self.client:
             raise TransportClosed("Transport is not connected")
 
-        post_args = self._prepare_batch_request(
+        post_args = self._prepare_request(
             reqs,
-            extra_args,
+            extra_args=extra_args,
         )
 
         response = self.client.post(self.url, **post_args)
@@ -361,8 +348,8 @@ async def execute(
 
         post_args = self._prepare_request(
             request,
-            extra_args,
-            upload_files,
+            extra_args=extra_args,
+            upload_files=upload_files,
         )
 
         try:
@@ -384,7 +371,7 @@ async def execute_batch(
         :code:`execute_batch` on a client or a session.
 
         :param reqs: GraphQL requests as a list of GraphQLRequest objects.
-        :param extra_args: additional arguments to send to the aiohttp post method
+        :param extra_args: additional arguments to send to the httpx post method
         :return: A list of results of execution.
             For every result `data` is the result of executing the query,
            `errors` is null if no errors occurred, and is a non-empty array
@@ -394,9 +381,9 @@ async def execute_batch(
         if not self.client:
             raise TransportClosed("Transport is not connected")
 
-        post_args = self._prepare_batch_request(
+        post_args = self._prepare_request(
             reqs,
-            extra_args,
+            extra_args=extra_args,
         )
 
         response = await self.client.post(self.url, **post_args)
diff --git a/gql/transport/requests.py b/gql/transport/requests.py
index 16d07025..17bf4695 100644
--- a/gql/transport/requests.py
+++ b/gql/transport/requests.py
@@ -137,32 +137,20 @@ def connect(self):
         else:
             raise TransportAlreadyConnected("Transport is already connected")
 
-    def execute(  # type: ignore
+    def _prepare_request(
         self,
-        request: GraphQLRequest,
+        request: Union[GraphQLRequest, List[GraphQLRequest]],
+        *,
         timeout: Optional[int] = None,
         extra_args: Optional[Dict[str, Any]] = None,
         upload_files: bool = False,
-    ) -> ExecutionResult:
-        """Execute GraphQL query.
-
-        Execute the provided request against the configured remote server. This
-        uses the requests library to perform a HTTP POST request to the remote server.
-
-        :param request: GraphQL request as a
-            :class:`GraphQLRequest` object.
-        :param timeout: Specifies a default timeout for requests (Default: None).
-        :param extra_args: additional arguments to send to the requests post method
-        :param upload_files: Set to True if you want to put files in the variable values
-        :return: The result of execution.
-            `data` is the result of executing the query, `errors` is null
-            if no errors occurred, and is a non-empty array if an error occurred.
-        """
-
-        if not self.session:
-            raise TransportClosed("Transport is not connected")
+    ) -> Dict[str, Any]:
 
-        payload = request.payload
+        payload: Dict | List
+        if isinstance(request, GraphQLRequest):
+            payload = request.payload
+        else:
+            payload = [req.payload for req in request]
 
         post_args: Dict[str, Any] = {
             "headers": self.headers,
@@ -173,111 +161,139 @@ def execute(  # type: ignore
         }
 
         if upload_files:
-            # If the upload_files flag is set, then we need variable_values
-            assert request.variable_values is not None
-
-            # If we upload files, we will extract the files present in the
-            # variable_values dict and replace them by null values
-            nulled_variable_values, files = extract_files(
-                variables=request.variable_values,
-                file_classes=self.file_classes,
+            assert isinstance(payload, Dict)
+            assert isinstance(request, GraphQLRequest)
+            post_args = self._prepare_file_uploads(
+                request=request,
+                payload=payload,
+                post_args=post_args,
             )
 
-            # Opening the files using the FileVar parameters
-            open_files(list(files.values()))
-            self.files = files
+        else:
+            data_key = "json" if self.use_json else "data"
+            post_args[data_key] = payload
 
-            # Save the nulled variable values in the payload
-            payload["variables"] = nulled_variable_values
+        # Log the payload
+        if log.isEnabledFor(logging.DEBUG):
+            log.debug(">>> %s", self.json_serialize(payload))
 
-            # Add the payload to the operations field
-            operations_str = self.json_serialize(payload)
-            log.debug("operations %s", operations_str)
+        # Pass kwargs to requests post method
+        post_args.update(self.kwargs)
 
-            # Generate the file map
-            # path is nested in a list because the spec allows multiple pointers
-            # to the same file. But we don't support that.
-            # Will generate something like {"0": ["variables.file"]}
-            file_map = {str(i): [path] for i, path in enumerate(files)}
+        # Pass post_args to requests post method
+        if extra_args:
+            post_args.update(extra_args)
+
+        return post_args
 
-            # Enumerate the file streams
-            # Will generate something like {'0': FileVar object}
-            file_vars = {str(i): files[path] for i, path in enumerate(files)}
+    def _prepare_file_uploads(
+        self,
+        request: GraphQLRequest,
+        *,
+        payload: Dict[str, Any],
+        post_args: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        # If the upload_files flag is set, then we need variable_values
+        assert request.variable_values is not None
+
+        # If we upload files, we will extract the files present in the
+        # variable_values dict and replace them by null values
+        nulled_variable_values, files = extract_files(
+            variables=request.variable_values,
+            file_classes=self.file_classes,
+        )
 
-            # Add the file map field
-            file_map_str = self.json_serialize(file_map)
-            log.debug("file_map %s", file_map_str)
+        # Opening the files using the FileVar parameters
+        open_files(list(files.values()))
+        self.files = files
 
-            fields = {"operations": operations_str, "map": file_map_str}
+        # Save the nulled variable values in the payload
+        payload["variables"] = nulled_variable_values
 
-            # Add the extracted files as remaining fields
-            for k, file_var in file_vars.items():
-                assert isinstance(file_var, FileVar)
-                name = k if file_var.filename is None else file_var.filename
+        # Add the payload to the operations field
+        operations_str = self.json_serialize(payload)
+        log.debug("operations %s", operations_str)
 
-                if file_var.content_type is None:
-                    fields[k] = (name, file_var.f)
-                else:
-                    fields[k] = (name, file_var.f, file_var.content_type)
+        # Generate the file map
+        # path is nested in a list because the spec allows multiple pointers
+        # to the same file. But we don't support that.
+        # Will generate something like {"0": ["variables.file"]}
+        file_map = {str(i): [path] for i, path in enumerate(files)}
 
-            # Prepare requests http to send multipart-encoded data
-            data = MultipartEncoder(fields=fields)
+        # Enumerate the file streams
+        # Will generate something like {'0': FileVar object}
+        file_vars = {str(i): files[path] for i, path in enumerate(files)}
 
-            post_args["data"] = data
+        # Add the file map field
+        file_map_str = self.json_serialize(file_map)
+        log.debug("file_map %s", file_map_str)
 
-            if post_args["headers"] is None:
-                post_args["headers"] = {}
+        fields = {"operations": operations_str, "map": file_map_str}
+
+        # Add the extracted files as remaining fields
+        for k, file_var in file_vars.items():
+            assert isinstance(file_var, FileVar)
+            name = k if file_var.filename is None else file_var.filename
+
+            if file_var.content_type is None:
+                fields[k] = (name, file_var.f)
             else:
-                post_args["headers"] = dict(post_args["headers"])
+                fields[k] = (name, file_var.f, file_var.content_type)
+
+        # Prepare requests http to send multipart-encoded data
+        data = MultipartEncoder(fields=fields)
 
-            post_args["headers"]["Content-Type"] = data.content_type
+        post_args["data"] = data
 
+        if post_args["headers"] is None:
+            post_args["headers"] = {}
         else:
-            data_key = "json" if self.use_json else "data"
-            post_args[data_key] = payload
+            post_args["headers"] = dict(post_args["headers"])
 
-            # Log the payload
-            if log.isEnabledFor(logging.DEBUG):
-                log.debug(">>> %s", self.json_serialize(payload))
+        post_args["headers"]["Content-Type"] = data.content_type
 
-        # Pass kwargs to requests post method
-        post_args.update(self.kwargs)
+        return post_args
 
-        # Pass post_args to requests post method
-        if extra_args:
-            post_args.update(extra_args)
+    def execute(
+        self,
+        request: GraphQLRequest,
+        timeout: Optional[int] = None,
+        extra_args: Optional[Dict[str, Any]] = None,
+        upload_files: bool = False,
+    ) -> ExecutionResult:
+        """Execute GraphQL query.
+
+        Execute the provided request against the configured remote server. This
+        uses the requests library to perform a HTTP POST request to the remote server.
+
+        :param request: GraphQL request as a
+            :class:`GraphQLRequest` object.
+        :param timeout: Specifies a default timeout for requests (Default: None).
+        :param extra_args: additional arguments to send to the requests post method
+        :param upload_files: Set to True if you want to put files in the variable values
+        :return: The result of execution.
+            `data` is the result of executing the query, `errors` is null
+            if no errors occurred, and is a non-empty array if an error occurred.
+        """
+
+        if not self.session:
+            raise TransportClosed("Transport is not connected")
+
+        post_args = self._prepare_request(
+            request,
+            timeout=timeout,
+            extra_args=extra_args,
+            upload_files=upload_files,
+        )
 
         # Using the created session to perform requests
         try:
-            response = self.session.request(
-                self.method, self.url, **post_args  # type: ignore
-            )
+            response = self.session.request(self.method, self.url, **post_args)
         finally:
             if upload_files:
                 close_files(list(self.files.values()))
 
-        self.response_headers = response.headers
-
-        try:
-            if self.json_deserialize == json.loads:
-                result = response.json()
-            else:
-                result = self.json_deserialize(response.text)
-
-            if log.isEnabledFor(logging.DEBUG):
-                log.debug("<<< %s", response.text)
-
-        except Exception:
-            self._raise_response_error(response, "Not a JSON answer")
-
-        if "errors" not in result and "data" not in result:
-            self._raise_response_error(response, 'No "data" or "errors" keys in answer')
-
-        return ExecutionResult(
-            errors=result.get("errors"),
-            data=result.get("data"),
-            extensions=result.get("extensions"),
-        )
+        return self._prepare_result(response)
 
     @staticmethod
     def _raise_transport_server_error_if_status_more_than_400(
@@ -327,27 +343,27 @@ def execute_batch(
         if not self.session:
             raise TransportClosed("Transport is not connected")
 
-        # Using the created session to perform requests
+        post_args = self._prepare_request(
+            reqs,
+            timeout=timeout,
+            extra_args=extra_args,
+        )
+
         response = self.session.request(
             self.method,
             self.url,
-            **self._build_batch_post_args(reqs, timeout, extra_args),
+            **post_args,
         )
-        self.response_headers = response.headers
 
-        answers = self._extract_response(response)
+        return self._prepare_batch_result(reqs, response)
 
-        try:
-            return get_batch_execution_result_list(reqs, answers)
-        except TransportProtocolError:
-            # Raise a TransportServerError if status > 400
-            self._raise_transport_server_error_if_status_more_than_400(response)
-            # In other cases, raise a TransportProtocolError
-            raise
+    def _get_json_result(self, response: requests.Response) -> Any:
+
+        # Saving latest response headers in the transport
+        self.response_headers = response.headers
 
-    def _extract_response(self, response: requests.Response) -> Any:
         try:
-            result = response.json()
+            result = self.json_deserialize(response.text)
 
             if log.isEnabledFor(logging.DEBUG):
                 log.debug("<<< %s", response.text)
@@ -357,35 +373,34 @@
 
         return result
 
-    def _build_batch_post_args(
-        self,
-        reqs: List[GraphQLRequest],
-        timeout: Optional[int] = None,
-        extra_args: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Any]:
-        post_args: Dict[str, Any] = {
-            "headers": self.headers,
-            "auth": self.auth,
-            "cookies": self.cookies,
-            "timeout": timeout or self.default_timeout,
-            "verify": self.verify,
-        }
+    def _prepare_result(self, response: requests.Response) -> ExecutionResult:
 
-        data_key = "json" if self.use_json else "data"
-        post_args[data_key] = [req.payload for req in reqs]
+        result = self._get_json_result(response)
 
-        # Log the payload
-        if log.isEnabledFor(logging.DEBUG):
-            log.debug(">>> %s", self.json_serialize(post_args[data_key]))
+        if "errors" not in result and "data" not in result:
+            self._raise_response_error(response, 'No "data" or "errors" keys in answer')
 
-        # Pass kwargs to requests post method
-        post_args.update(self.kwargs)
+        return ExecutionResult(
+            errors=result.get("errors"),
+            data=result.get("data"),
+            extensions=result.get("extensions"),
+        )
 
-        # Pass post_args to requests post method
-        if extra_args:
-            post_args.update(extra_args)
+    def _prepare_batch_result(
+        self,
+        reqs: List[GraphQLRequest],
+        response: requests.Response,
+    ) -> List[ExecutionResult]:
 
-        return post_args
+        answers = self._get_json_result(response)
+
+        try:
+            return get_batch_execution_result_list(reqs, answers)
+        except TransportProtocolError:
+            # Raise a TransportServerError if status > 400
+            self._raise_transport_server_error_if_status_more_than_400(response)
+            # In other cases, raise a TransportProtocolError
+            raise
 
     def close(self):
         """Closing the transport by closing the inner session"""
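
Below is a minimal, self-contained sketch of the dispatch pattern this refactor applies in all three transports: a single _prepare_request() accepts either one GraphQLRequest or a list of them, so execute() and execute_batch() can share the same payload-building path. The GraphQLRequest dataclass and the standalone prepare_request() function here are simplified stand-ins for illustration only, not the library's actual classes; only the Union[GraphQLRequest, List[GraphQLRequest]] dispatch mirrors the diff.

import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

log = logging.getLogger(__name__)


@dataclass
class GraphQLRequest:
    # Simplified stand-in for gql's request object (illustration only).
    query: str
    variable_values: Optional[Dict[str, Any]] = None

    @property
    def payload(self) -> Dict[str, Any]:
        result: Dict[str, Any] = {"query": self.query}
        if self.variable_values is not None:
            result["variables"] = self.variable_values
        return result


def prepare_request(
    request: Union[GraphQLRequest, List[GraphQLRequest]],
    extra_args: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    # Single request -> dict payload; batch -> list of payloads.
    payload: Union[Dict[str, Any], List[Dict[str, Any]]]
    if isinstance(request, GraphQLRequest):
        payload = request.payload
    else:
        payload = [req.payload for req in request]

    post_args: Dict[str, Any] = {"json": payload}

    # Log the payload only when debug logging is enabled.
    if log.isEnabledFor(logging.DEBUG):
        log.debug(">>> %s", json.dumps(payload))

    # Caller-supplied keyword arguments override the defaults, as in the transports.
    if extra_args:
        post_args.update(extra_args)

    return post_args


if __name__ == "__main__":
    single = GraphQLRequest("{ hero { name } }")
    batch = [single, GraphQLRequest("query ($id: ID!) { droid(id: $id) { name } }", {"id": "2000"})]
    print(prepare_request(single))  # {'json': {'query': '{ hero { name } }'}}
    print(prepare_request(batch))   # {'json': [{...}, {...}]}

File uploads remain single-request only, which is why the upload branch in the diff asserts isinstance(request, GraphQLRequest) and isinstance(payload, Dict) before calling _prepare_file_uploads().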