From 8f58dfe54265947413c385e48a9f9e2cc589c404 Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 08:55:29 +0200 Subject: [PATCH 01/14] added failing test --- .../test_testing/test_lifespan_handler.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_testing/test_lifespan_handler.py b/tests/unit/test_testing/test_lifespan_handler.py index f04f91a959..bd4bcd9f07 100644 --- a/tests/unit/test_testing/test_lifespan_handler.py +++ b/tests/unit/test_testing/test_lifespan_handler.py @@ -1,6 +1,9 @@ +import asyncio + import pytest -from litestar.testing import TestClient +from litestar import Litestar, get +from litestar.testing import AsyncTestClient, TestClient from litestar.testing.life_span_handler import LifeSpanHandler from litestar.types import Receive, Scope, Send @@ -35,3 +38,17 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: handler = LifeSpanHandler(TestClient(app)) await handler.wait_shutdown() handler.close() + + +async def test_multiple_clients_event_loop() -> None: + @get("/") + def return_loop_id() -> dict: + return {"loop_id": id(asyncio.get_event_loop())} + + app = Litestar(route_handlers=[return_loop_id]) + + async with AsyncTestClient(app) as client_1, AsyncTestClient(app) as client_2: + response_1 = await client_1.get("/") + response_2 = await client_2.get("/") + + assert response_1.json() == response_2.json() # FAILS From 7d756cbee888053886f9e94cb848f07d71a4cae2 Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 09:18:27 +0200 Subject: [PATCH 02/14] fix: add a AsyncTestClientTransport for the AsyncTestClient --- litestar/testing/client/async_client.py | 4 +- litestar/testing/transport.py | 152 +++++++++++++++++++++++- 2 files changed, 150 insertions(+), 6 deletions(-) diff --git a/litestar/testing/client/async_client.py b/litestar/testing/client/async_client.py index 4bf9eec087..edce6b2e85 100644 --- a/litestar/testing/client/async_client.py +++ b/litestar/testing/client/async_client.py @@ -7,7 +7,7 @@ from litestar.testing.client.base import BaseTestClient from litestar.testing.life_span_handler import LifeSpanHandler -from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport +from litestar.testing.transport import AsyncTestClientTransport, ConnectionUpgradeExceptionError from litestar.types import AnyIOBackend, ASGIApp if TYPE_CHECKING: @@ -74,7 +74,7 @@ def __init__( headers={"user-agent": "testclient"}, follow_redirects=True, cookies=cookies, - transport=TestClientTransport( # type: ignore [arg-type] + transport=AsyncTestClientTransport( # type: ignore [arg-type] client=self, raise_server_exceptions=raise_server_exceptions, root_path=root_path, diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index 7a965dc9fd..26ac95ba8a 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -43,6 +43,153 @@ class SendReceiveContext(TypedDict): context: Any | None +class AsyncTestClientTransport(Generic[T]): + def __init__( + self, + client: T, + raise_server_exceptions: bool = True, + root_path: str = "", + ) -> None: + self.client = client + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + + @staticmethod + def create_receive(request: Request, context: SendReceiveContext) -> Receive: + async def receive() -> ReceiveMessage: + if context["request_complete"]: + if not context["response_complete"].is_set(): + await context["response_complete"].wait() + disconnect_event: HTTPDisconnectEvent = {"type": "http.disconnect"} + return disconnect_event + + body = cast("Union[bytes, str, GeneratorType]", (request.read() or b"")) + request_event: HTTPRequestEvent = {"type": "http.request", "body": b"", "more_body": False} + if isinstance(body, GeneratorType): # pragma: no cover + try: + chunk = body.send(None) + request_event["body"] = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + request_event["more_body"] = True + except StopIteration: + context["request_complete"] = True + else: + context["request_complete"] = True + request_event["body"] = body if isinstance(body, bytes) else body.encode("utf-8") + return request_event + + return receive + + @staticmethod + def create_send(request: Request, context: SendReceiveContext) -> Send: + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + assert not context["response_started"], 'Received multiple "http.response.start" messages.' # noqa: S101 + context["raw_kwargs"]["status_code"] = message["status"] + context["raw_kwargs"]["headers"] = [ + (k.decode("utf-8"), v.decode("utf-8")) for k, v in message.get("headers", []) + ] + context["response_started"] = True + elif message["type"] == "http.response.body": + assert context["response_started"], 'Received "http.response.body" without "http.response.start".' # noqa: S101 + assert not context[ # noqa: S101 + "response_complete" + ].is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + context["raw_kwargs"]["stream"].write(body) + if not more_body: + context["raw_kwargs"]["stream"].seek(0) + context["response_complete"].set() + elif message["type"] == "http.response.template": # type: ignore[comparison-overlap] # pragma: no cover + context["template"] = message["template"] # type: ignore[unreachable] + context["context"] = message["context"] + + return send + + def parse_request(self, request: Request) -> dict[str, Any]: + scheme = request.url.scheme + netloc = unquote(request.url.netloc.decode(encoding="ascii")) + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + default_port = 433 if scheme in {"https", "wss"} else 80 + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + host_header = request.headers.pop("host", host if port == default_port else f"{host}:{port}") + + headers = [(k.lower().encode(), v.encode()) for k, v in (("host", host_header), *request.headers.items())] + + return { + "type": "websocket" if scheme in {"ws", "wss"} else "http", + "path": unquote(path), + "raw_path": raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ("testclient", 50000), + "server": (host, port), + } + + async def handle_async_request(self, request: Request) -> Response: + scope = self.parse_request(request=request) + if scope["type"] == "websocket": + scope.update( + subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] + ) + session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type] + raise ConnectionUpgradeExceptionError(session) + + scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) + + raw_kwargs: dict[str, Any] = {"stream": BytesIO()} + response_complete = Event() + context: SendReceiveContext = { + "response_complete": response_complete, + "request_complete": False, + "raw_kwargs": raw_kwargs, + "response_started": False, + "template": None, + "context": None, + } + + try: + await self.client.app( + scope, + self.create_receive(request=request, context=context), + self.create_send(request=request, context=context), + ) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + else: + if not context["response_started"]: # pragma: no cover + if self.raise_server_exceptions: + assert context["response_started"], "TestClient did not receive any response." # noqa: S101 + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + + stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) + response = Response(**raw_kwargs, stream=stream, request=request) + setattr(response, "template", context["template"]) + setattr(response, "context", context["context"]) + return response + + async def aclose(self) -> None: + """Close the transport.""" + + class TestClientTransport(Generic[T]): def __init__( self, @@ -138,7 +285,7 @@ def parse_request(self, request: Request) -> dict[str, Any]: "server": (host, port), } - def handle_request(self, request: Request) -> Response: + def handle_async_request(self, request: Request) -> Response: scope = self.parse_request(request=request) if scope["type"] == "websocket": scope.update( @@ -187,6 +334,3 @@ def handle_request(self, request: Request) -> Response: setattr(response, "template", context["template"]) setattr(response, "context", context["context"]) return response - - async def handle_async_request(self, request: Request) -> Response: - return self.handle_request(request=request) From 87f20f5cab1c492569d4f719f25976f716f65165 Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 09:23:29 +0200 Subject: [PATCH 03/14] useless changes --- litestar/testing/transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index 26ac95ba8a..f21fc87fdc 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -285,7 +285,7 @@ def parse_request(self, request: Request) -> dict[str, Any]: "server": (host, port), } - def handle_async_request(self, request: Request) -> Response: + def handle_request(self, request: Request) -> Response: scope = self.parse_request(request=request) if scope["type"] == "websocket": scope.update( From ceeb8392b825b32a9f33e569d2b1efd2d52f5e37 Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 09:31:04 +0200 Subject: [PATCH 04/14] import --- tests/unit/test_testing/test_lifespan_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_testing/test_lifespan_handler.py b/tests/unit/test_testing/test_lifespan_handler.py index bd4bcd9f07..7d8934e14a 100644 --- a/tests/unit/test_testing/test_lifespan_handler.py +++ b/tests/unit/test_testing/test_lifespan_handler.py @@ -1,4 +1,4 @@ -import asyncio +from asyncio import get_event_loop import pytest @@ -43,7 +43,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: async def test_multiple_clients_event_loop() -> None: @get("/") def return_loop_id() -> dict: - return {"loop_id": id(asyncio.get_event_loop())} + return {"loop_id": id(get_event_loop())} app = Litestar(route_handlers=[return_loop_id]) From 57018574aaa2661ae8cc4df13d6b109f63391583 Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 10:02:56 +0200 Subject: [PATCH 05/14] small refactor --- litestar/testing/transport.py | 104 ++-------------------------------- 1 file changed, 6 insertions(+), 98 deletions(-) diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index f21fc87fdc..a8236356ad 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -43,7 +43,7 @@ class SendReceiveContext(TypedDict): context: Any | None -class AsyncTestClientTransport(Generic[T]): +class BaseTestClientTransport(Generic[T]): def __init__( self, client: T, @@ -138,13 +138,15 @@ def parse_request(self, request: Request) -> dict[str, Any]: "server": (host, port), } + +class AsyncTestClientTransport(BaseTestClientTransport): async def handle_async_request(self, request: Request) -> Response: scope = self.parse_request(request=request) if scope["type"] == "websocket": scope.update( subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] ) - session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type] + session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) raise ConnectionUpgradeExceptionError(session) scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) @@ -190,108 +192,14 @@ async def aclose(self) -> None: """Close the transport.""" -class TestClientTransport(Generic[T]): - def __init__( - self, - client: T, - raise_server_exceptions: bool = True, - root_path: str = "", - ) -> None: - self.client = client - self.raise_server_exceptions = raise_server_exceptions - self.root_path = root_path - - @staticmethod - def create_receive(request: Request, context: SendReceiveContext) -> Receive: - async def receive() -> ReceiveMessage: - if context["request_complete"]: - if not context["response_complete"].is_set(): - await context["response_complete"].wait() - disconnect_event: HTTPDisconnectEvent = {"type": "http.disconnect"} - return disconnect_event - - body = cast("Union[bytes, str, GeneratorType]", (request.read() or b"")) - request_event: HTTPRequestEvent = {"type": "http.request", "body": b"", "more_body": False} - if isinstance(body, GeneratorType): # pragma: no cover - try: - chunk = body.send(None) - request_event["body"] = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - request_event["more_body"] = True - except StopIteration: - context["request_complete"] = True - else: - context["request_complete"] = True - request_event["body"] = body if isinstance(body, bytes) else body.encode("utf-8") - return request_event - - return receive - - @staticmethod - def create_send(request: Request, context: SendReceiveContext) -> Send: - async def send(message: Message) -> None: - if message["type"] == "http.response.start": - assert not context["response_started"], 'Received multiple "http.response.start" messages.' # noqa: S101 - context["raw_kwargs"]["status_code"] = message["status"] - context["raw_kwargs"]["headers"] = [ - (k.decode("utf-8"), v.decode("utf-8")) for k, v in message.get("headers", []) - ] - context["response_started"] = True - elif message["type"] == "http.response.body": - assert context["response_started"], 'Received "http.response.body" without "http.response.start".' # noqa: S101 - assert not context[ # noqa: S101 - "response_complete" - ].is_set(), 'Received "http.response.body" after response completed.' - body = message.get("body", b"") - more_body = message.get("more_body", False) - if request.method != "HEAD": - context["raw_kwargs"]["stream"].write(body) - if not more_body: - context["raw_kwargs"]["stream"].seek(0) - context["response_complete"].set() - elif message["type"] == "http.response.template": # type: ignore[comparison-overlap] # pragma: no cover - context["template"] = message["template"] # type: ignore[unreachable] - context["context"] = message["context"] - - return send - - def parse_request(self, request: Request) -> dict[str, Any]: - scheme = request.url.scheme - netloc = unquote(request.url.netloc.decode(encoding="ascii")) - path = request.url.path - raw_path = request.url.raw_path - query = request.url.query.decode(encoding="ascii") - default_port = 433 if scheme in {"https", "wss"} else 80 - - if ":" in netloc: - host, port_string = netloc.split(":", 1) - port = int(port_string) - else: - host = netloc - port = default_port - - host_header = request.headers.pop("host", host if port == default_port else f"{host}:{port}") - - headers = [(k.lower().encode(), v.encode()) for k, v in (("host", host_header), *request.headers.items())] - - return { - "type": "websocket" if scheme in {"ws", "wss"} else "http", - "path": unquote(path), - "raw_path": raw_path, - "root_path": self.root_path, - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ("testclient", 50000), - "server": (host, port), - } - +class TestClientTransport(BaseTestClientTransport): def handle_request(self, request: Request) -> Response: scope = self.parse_request(request=request) if scope["type"] == "websocket": scope.update( subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] ) - session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type] + session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) raise ConnectionUpgradeExceptionError(session) scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) From c6a4b2a73ac6ceb9914072a34fb5a521cc0d7a1d Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 10:18:26 +0200 Subject: [PATCH 06/14] small refactor 2 --- litestar/testing/transport.py | 25 ++++++++-------------- litestar/testing/websocket_test_session.py | 3 ++- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index a8236356ad..a18d7129bf 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -138,9 +138,7 @@ def parse_request(self, request: Request) -> dict[str, Any]: "server": (host, port), } - -class AsyncTestClientTransport(BaseTestClientTransport): - async def handle_async_request(self, request: Request) -> Response: + def _prepare_request(self, request: Request) -> tuple[dict[str, Any], dict[str, Any]]: scope = self.parse_request(request=request) if scope["type"] == "websocket": scope.update( @@ -148,10 +146,15 @@ async def handle_async_request(self, request: Request) -> Response: ) session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) raise ConnectionUpgradeExceptionError(session) - scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) - raw_kwargs: dict[str, Any] = {"stream": BytesIO()} + return raw_kwargs, scope + + +class AsyncTestClientTransport(BaseTestClientTransport): + async def handle_async_request(self, request: Request) -> Response: + raw_kwargs, scope = self._prepare_request(request) + response_complete = Event() context: SendReceiveContext = { "response_complete": response_complete, @@ -194,17 +197,7 @@ async def aclose(self) -> None: class TestClientTransport(BaseTestClientTransport): def handle_request(self, request: Request) -> Response: - scope = self.parse_request(request=request) - if scope["type"] == "websocket": - scope.update( - subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] - ) - session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) - raise ConnectionUpgradeExceptionError(session) - - scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) - - raw_kwargs: dict[str, Any] = {"stream": BytesIO()} + raw_kwargs, scope = self._prepare_request(request) try: with self.client.portal() as portal: diff --git a/litestar/testing/websocket_test_session.py b/litestar/testing/websocket_test_session.py index 38901f733b..c295646986 100644 --- a/litestar/testing/websocket_test_session.py +++ b/litestar/testing/websocket_test_session.py @@ -11,6 +11,7 @@ from litestar.status_codes import WS_1000_NORMAL_CLOSURE if TYPE_CHECKING: + from litestar.testing.client.async_client import AsyncTestClient from litestar.testing.client.sync_client import TestClient from litestar.types import ( WebSocketConnectEvent, @@ -29,7 +30,7 @@ class WebSocketTestSession: def __init__( self, - client: TestClient[Any], + client: TestClient[Any] | AsyncTestClient[Any], scope: WebSocketScope, ) -> None: self.client = client From 15850993942d9c6a6ab2ff75ba55de6677e4f8aa Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 10:32:42 +0200 Subject: [PATCH 07/14] small refactor 3 --- litestar/testing/transport.py | 40 ++++++++++++++--------------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index a18d7129bf..1f515aea7e 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -150,6 +150,20 @@ def _prepare_request(self, request: Request) -> tuple[dict[str, Any], dict[str, raw_kwargs: dict[str, Any] = {"stream": BytesIO()} return raw_kwargs, scope + def _prepare_response(self, request: Request, context: SendReceiveContext, raw_kwargs: dict[str, Any]) -> Response: + if not context["response_started"]: # pragma: no cover + if self.raise_server_exceptions: + assert context["response_started"], "TestClient did not receive any response." # noqa: S101 + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + + stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) + response = Response(**raw_kwargs, stream=stream, request=request) + setattr(response, "template", context["template"]) + setattr(response, "context", context["context"]) + return response + class AsyncTestClientTransport(BaseTestClientTransport): async def handle_async_request(self, request: Request) -> Response: @@ -178,18 +192,7 @@ async def handle_async_request(self, request: Request) -> Response: status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request ) else: - if not context["response_started"]: # pragma: no cover - if self.raise_server_exceptions: - assert context["response_started"], "TestClient did not receive any response." # noqa: S101 - return Response( - status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request - ) - - stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) - response = Response(**raw_kwargs, stream=stream, request=request) - setattr(response, "template", context["template"]) - setattr(response, "context", context["context"]) - return response + return self._prepare_response(request, context, raw_kwargs) async def aclose(self) -> None: """Close the transport.""" @@ -223,15 +226,4 @@ def handle_request(self, request: Request) -> Response: status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request ) else: - if not context["response_started"]: # pragma: no cover - if self.raise_server_exceptions: - assert context["response_started"], "TestClient did not receive any response." # noqa: S101 - return Response( - status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request - ) - - stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) - response = Response(**raw_kwargs, stream=stream, request=request) - setattr(response, "template", context["template"]) - setattr(response, "context", context["context"]) - return response + return self._prepare_response(request, context, raw_kwargs) From 28fa120149900558d2a943573645e54f4e989ff4 Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 10:35:31 +0200 Subject: [PATCH 08/14] leftover --- litestar/testing/transport.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index 1f515aea7e..ab67e9dd86 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -194,9 +194,6 @@ async def handle_async_request(self, request: Request) -> Response: else: return self._prepare_response(request, context, raw_kwargs) - async def aclose(self) -> None: - """Close the transport.""" - class TestClientTransport(BaseTestClientTransport): def handle_request(self, request: Request) -> Response: From a507e1cbdf570fc428416281a8f862f893e067df Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 10:37:55 +0200 Subject: [PATCH 09/14] simplify --- litestar/testing/transport.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index ab67e9dd86..44f85b60a6 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -168,18 +168,16 @@ def _prepare_response(self, request: Request, context: SendReceiveContext, raw_k class AsyncTestClientTransport(BaseTestClientTransport): async def handle_async_request(self, request: Request) -> Response: raw_kwargs, scope = self._prepare_request(request) - - response_complete = Event() - context: SendReceiveContext = { - "response_complete": response_complete, - "request_complete": False, - "raw_kwargs": raw_kwargs, - "response_started": False, - "template": None, - "context": None, - } - try: + response_complete = Event() + context: SendReceiveContext = { + "response_complete": response_complete, + "request_complete": False, + "raw_kwargs": raw_kwargs, + "response_started": False, + "template": None, + "context": None, + } await self.client.app( scope, self.create_receive(request=request, context=context), From b1d0fcceeb2c0af966252f0ee86073d0ce2c2bd1 Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 11:25:06 +0200 Subject: [PATCH 10/14] better type hinting --- litestar/testing/client/async_client.py | 2 +- litestar/testing/client/sync_client.py | 2 +- litestar/testing/transport.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/litestar/testing/client/async_client.py b/litestar/testing/client/async_client.py index edce6b2e85..712276495a 100644 --- a/litestar/testing/client/async_client.py +++ b/litestar/testing/client/async_client.py @@ -74,7 +74,7 @@ def __init__( headers={"user-agent": "testclient"}, follow_redirects=True, cookies=cookies, - transport=AsyncTestClientTransport( # type: ignore [arg-type] + transport=AsyncTestClientTransport( client=self, raise_server_exceptions=raise_server_exceptions, root_path=root_path, diff --git a/litestar/testing/client/sync_client.py b/litestar/testing/client/sync_client.py index 9c58d139d2..656988ee38 100644 --- a/litestar/testing/client/sync_client.py +++ b/litestar/testing/client/sync_client.py @@ -75,7 +75,7 @@ def __init__( headers={"user-agent": "testclient"}, follow_redirects=True, cookies=cookies, - transport=TestClientTransport( # type: ignore[arg-type] + transport=TestClientTransport( client=self, raise_server_exceptions=raise_server_exceptions, root_path=root_path, diff --git a/litestar/testing/transport.py b/litestar/testing/transport.py index 44f85b60a6..f7fb949dfd 100644 --- a/litestar/testing/transport.py +++ b/litestar/testing/transport.py @@ -6,7 +6,7 @@ from urllib.parse import unquote from anyio import Event -from httpx import ByteStream, Response +from httpx import AsyncBaseTransport, BaseTransport, ByteStream, Response from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing.websocket_test_session import WebSocketTestSession @@ -165,7 +165,7 @@ def _prepare_response(self, request: Request, context: SendReceiveContext, raw_k return response -class AsyncTestClientTransport(BaseTestClientTransport): +class AsyncTestClientTransport(AsyncBaseTransport, BaseTestClientTransport, Generic[T]): async def handle_async_request(self, request: Request) -> Response: raw_kwargs, scope = self._prepare_request(request) try: @@ -193,7 +193,7 @@ async def handle_async_request(self, request: Request) -> Response: return self._prepare_response(request, context, raw_kwargs) -class TestClientTransport(BaseTestClientTransport): +class TestClientTransport(BaseTransport, BaseTestClientTransport, Generic[T]): def handle_request(self, request: Request) -> Response: raw_kwargs, scope = self._prepare_request(request) From 13367dd176a50ed74c082f5db6755b452c10f4fd Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 12:00:29 +0200 Subject: [PATCH 11/14] missing coverage ? --- tests/unit/test_signature/test_validation.py | 27 +++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index 23f1cb0318..31146d8e78 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -12,7 +12,7 @@ from litestar.exceptions import ImproperlyConfiguredException, ValidationException from litestar.params import Dependency, Parameter from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR -from litestar.testing import RequestFactory, create_test_client +from litestar.testing import RequestFactory, create_async_test_client, create_test_client from litestar.utils.signature import ParsedSignature @@ -341,3 +341,28 @@ async def get_deserializer(deserializer: int) -> str: res = client.get("/deserializer") assert res.status_code == 500 assert "Expected `int`, got `str` - at `$.deserializer`" in res.text + + +async def test_separate_model_namespace_async_client() -> None: + async def provide_connection() -> str: + return "connection" + + @get("/connection", dependencies={"connection": provide_connection}) + async def get_connection(connection: str) -> str: + return connection + + async def provide_deserializer() -> str: + return "deserializer" + + @get("/deserializer", dependencies={"deserializer": provide_deserializer}) + async def get_deserializer(deserializer: int) -> str: + return deserializer # type: ignore[return-value] + + async with create_async_test_client( + [get_connection, get_deserializer], raise_server_exceptions=True, debug=True + ) as client: + c = await client.get("/connection") + assert c.text == "connection" + res = await client.get("/deserializer") + assert res.status_code == 500 + assert "Expected `int`, got `str` - at `$.deserializer`" in res.text From 4dffcb509ce4909b9c7d00a4bc35e4457944498e Mon Sep 17 00:00:00 2001 From: euri10 Date: Thu, 24 Apr 2025 16:24:14 +0200 Subject: [PATCH 12/14] still coverage --- tests/unit/test_signature/test_validation.py | 27 +------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index 31146d8e78..23f1cb0318 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -12,7 +12,7 @@ from litestar.exceptions import ImproperlyConfiguredException, ValidationException from litestar.params import Dependency, Parameter from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR -from litestar.testing import RequestFactory, create_async_test_client, create_test_client +from litestar.testing import RequestFactory, create_test_client from litestar.utils.signature import ParsedSignature @@ -341,28 +341,3 @@ async def get_deserializer(deserializer: int) -> str: res = client.get("/deserializer") assert res.status_code == 500 assert "Expected `int`, got `str` - at `$.deserializer`" in res.text - - -async def test_separate_model_namespace_async_client() -> None: - async def provide_connection() -> str: - return "connection" - - @get("/connection", dependencies={"connection": provide_connection}) - async def get_connection(connection: str) -> str: - return connection - - async def provide_deserializer() -> str: - return "deserializer" - - @get("/deserializer", dependencies={"deserializer": provide_deserializer}) - async def get_deserializer(deserializer: int) -> str: - return deserializer # type: ignore[return-value] - - async with create_async_test_client( - [get_connection, get_deserializer], raise_server_exceptions=True, debug=True - ) as client: - c = await client.get("/connection") - assert c.text == "connection" - res = await client.get("/deserializer") - assert res.status_code == 500 - assert "Expected `int`, got `str` - at `$.deserializer`" in res.text From eec69c187438d57febc0086902b11c8a39042f2a Mon Sep 17 00:00:00 2001 From: euri10 Date: Sat, 26 Apr 2025 14:33:03 +0200 Subject: [PATCH 13/14] add async lifespan handler test with test breaks 2 tests --- litestar/testing/client/async_client.py | 26 +--- litestar/testing/life_span_handler.py | 113 +++++++++++++++++- .../test_testing/test_lifespan_handler.py | 22 +++- 3 files changed, 137 insertions(+), 24 deletions(-) diff --git a/litestar/testing/client/async_client.py b/litestar/testing/client/async_client.py index 712276495a..d895251bab 100644 --- a/litestar/testing/client/async_client.py +++ b/litestar/testing/client/async_client.py @@ -1,12 +1,11 @@ from __future__ import annotations -from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar from httpx import USE_CLIENT_DEFAULT, AsyncClient from litestar.testing.client.base import BaseTestClient -from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.life_span_handler import AsyncLifeSpanHandler from litestar.testing.transport import AsyncTestClientTransport, ConnectionUpgradeExceptionError from litestar.types import AnyIOBackend, ASGIApp @@ -24,14 +23,10 @@ from litestar.middleware.session.base import BaseBackendConfig from litestar.testing.websocket_test_session import WebSocketTestSession - T = TypeVar("T", bound=ASGIApp) class AsyncTestClient(AsyncClient, BaseTestClient, Generic[T]): # type: ignore[misc] - lifespan_handler: LifeSpanHandler[Any] - exit_stack: AsyncExitStack - def __init__( self, app: T, @@ -83,24 +78,11 @@ def __init__( ) async def __aenter__(self) -> Self: - async with AsyncExitStack() as stack: - self.blocking_portal = portal = stack.enter_context(self.portal()) - self.lifespan_handler = LifeSpanHandler(client=self) - stack.enter_context(self.lifespan_handler) - - @stack.callback - def reset_portal() -> None: - delattr(self, "blocking_portal") - - @stack.callback - def wait_shutdown() -> None: - portal.call(self.lifespan_handler.wait_shutdown) - - self.exit_stack = stack.pop_all() - return self + async with AsyncLifeSpanHandler(client=self) as _: + return self async def __aexit__(self, *args: Any) -> None: - await self.exit_stack.aclose() + pass async def websocket_connect( self, diff --git a/litestar/testing/life_span_handler.py b/litestar/testing/life_span_handler.py index 7141f4eeb5..7d61efeba1 100644 --- a/litestar/testing/life_span_handler.py +++ b/litestar/testing/life_span_handler.py @@ -1,10 +1,12 @@ from __future__ import annotations +import contextlib import warnings from math import inf from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast -from anyio import create_memory_object_stream +import anyio +from anyio import TASK_STATUS_IGNORED, create_memory_object_stream from anyio.streams.stapled import StapledObjectStream from litestar.testing.client.base import BaseTestClient @@ -12,6 +14,8 @@ if TYPE_CHECKING: from types import TracebackType + from anyio.abc import TaskStatus + from litestar.types import ( LifeSpanReceiveMessage, # noqa: F401 LifeSpanSendMessage, @@ -128,3 +132,110 @@ async def lifespan(self) -> None: await self.client.app(scope, self.stream_receive.receive, self.stream_send.send) finally: await self.stream_send.send(None) + + +class AsyncLifeSpanHandler(Generic[T]): + __slots__ = ( + "_startup_done", + "client", + "stream_receive", + "stream_send", + "task", + ) + + def __init__(self, client: T) -> None: + self.client = client + self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf)) # type: ignore[arg-type] + self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf)) # type: ignore[arg-type] + self._startup_done = False + + async def _ensure_setup(self, is_safe: bool = False) -> None: + if self._startup_done: + return + if not is_safe: + warnings.warn( + "AsyncLifeSpanHandler used with implicit startup; Use AsyncLifeSpanHandler as a async context manager instead. " + "Implicit startup will be deprecated in version 3.0.", + category=DeprecationWarning, + stacklevel=2, + ) + + self._startup_done = True + async with anyio.create_task_group() as task_group: + await task_group.start(self.wait_startup) + self.task = task_group.start_soon(self.lifespan) + await task_group.start(self.wait_shutdown) + + async def aclose(self) -> None: + await self.stream_send.aclose() + await self.stream_receive.aclose() + + async def __aenter__(self) -> AsyncLifeSpanHandler: + try: + await self._ensure_setup(is_safe=True) + except Exception as exc: + await self.aclose() + raise exc + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.aclose() + + async def receive(self) -> LifeSpanSendMessage: + await self._ensure_setup() + message = await self.stream_send.receive() + if message is None: + self.task.result() + return cast("LifeSpanSendMessage", message) + + async def wait_startup(self, *, task_status: TaskStatus[None]) -> None: + task_status.started() + await self._ensure_setup() + event: LifeSpanStartupEvent = {"type": "lifespan.startup"} + await self.stream_receive.send(event) + + message = await self.receive() + if message["type"] not in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ): + raise RuntimeError( + "Received unexpected ASGI message type. Expected 'lifespan.startup.complete' or " + f"'lifespan.startup.failed'. Got {message['type']!r}", + ) + if message["type"] == "lifespan.startup.failed": + await self.receive() + + async def wait_shutdown(self, *, task_status: TaskStatus[None]) -> None: + task_status.started() + await self._ensure_setup() + async with self.stream_send: + lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"} + await self.stream_receive.send(lifespan_shutdown_event) + + message = await self.receive() + if message["type"] not in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ): + raise RuntimeError( + "Received unexpected ASGI message type. Expected 'lifespan.shutdown.complete' or " + f"'lifespan.shutdown.failed'. Got {message['type']!r}", + ) + if message["type"] == "lifespan.shutdown.failed": + await self.receive() + + async def lifespan(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + task_status.started() + await self._ensure_setup() + scope = {"type": "lifespan"} + try: + await self.client.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + with contextlib.suppress(anyio.ClosedResourceError): + await self.stream_send.send(None) diff --git a/tests/unit/test_testing/test_lifespan_handler.py b/tests/unit/test_testing/test_lifespan_handler.py index 7d8934e14a..0a4bad972e 100644 --- a/tests/unit/test_testing/test_lifespan_handler.py +++ b/tests/unit/test_testing/test_lifespan_handler.py @@ -1,4 +1,8 @@ +import asyncio +import contextlib from asyncio import get_event_loop +from typing import AsyncGenerator +from unittest.mock import MagicMock import pytest @@ -51,4 +55,20 @@ def return_loop_id() -> dict: response_1 = await client_1.get("/") response_2 = await client_2.get("/") - assert response_1.json() == response_2.json() # FAILS + assert response_1.json() == response_2.json() + + +async def test_lifespan_loop() -> None: + mock = MagicMock() + + @contextlib.asynccontextmanager + async def lifespan(app: Litestar) -> AsyncGenerator[None, None]: + mock(asyncio.get_running_loop()) + yield + + app = Litestar(lifespan=[lifespan]) + + async with AsyncTestClient(app): + pass + + mock.assert_called_once_with(asyncio.get_running_loop()) From 18fe44a3264b6980fbd2352cfd186e4f5e5e77eb Mon Sep 17 00:00:00 2001 From: euri10 Date: Sat, 26 Apr 2025 15:02:28 +0200 Subject: [PATCH 14/14] fix test --- litestar/testing/life_span_handler.py | 4 ++-- tests/unit/test_testing/test_test_client.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/litestar/testing/life_span_handler.py b/litestar/testing/life_span_handler.py index 7d61efeba1..1b32eca4ff 100644 --- a/litestar/testing/life_span_handler.py +++ b/litestar/testing/life_span_handler.py @@ -163,7 +163,7 @@ async def _ensure_setup(self, is_safe: bool = False) -> None: self._startup_done = True async with anyio.create_task_group() as task_group: await task_group.start(self.wait_startup) - self.task = task_group.start_soon(self.lifespan) + self.task = await task_group.start(self.lifespan) await task_group.start(self.wait_shutdown) async def aclose(self) -> None: @@ -189,7 +189,7 @@ async def __aexit__( async def receive(self) -> LifeSpanSendMessage: await self._ensure_setup() message = await self.stream_send.receive() - if message is None: + if message is None and self.task: self.task.result() return cast("LifeSpanSendMessage", message) diff --git a/tests/unit/test_testing/test_test_client.py b/tests/unit/test_testing/test_test_client.py index f35b9ce0e9..e60feaeb6c 100644 --- a/tests/unit/test_testing/test_test_client.py +++ b/tests/unit/test_testing/test_test_client.py @@ -152,7 +152,7 @@ async def test_error_handling_on_startup( async def test_error_handling_on_shutdown( test_client_backend: "AnyIOBackend", test_client_cls: Type[AnyTestClient] ) -> None: - with pytest.raises(RuntimeError): + with pytest.raises(_ExceptionGroup): async with maybe_async_cm( test_client_cls(Litestar(on_shutdown=[raise_error]), backend=test_client_backend) # pyright: ignore ):