Skip to content

fix: add a AsyncTestClientTransport for the AsyncTestClient #4142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions litestar/testing/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from litestar.testing.client.base import BaseTestClient
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport
from litestar.testing.transport import AsyncTestClientTransport, ConnectionUpgradeExceptionError
from litestar.types import AnyIOBackend, ASGIApp

if TYPE_CHECKING:
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
headers={"user-agent": "testclient"},
follow_redirects=True,
cookies=cookies,
transport=TestClientTransport( # type: ignore [arg-type]
transport=AsyncTestClientTransport(
client=self,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
Expand Down
2 changes: 1 addition & 1 deletion litestar/testing/client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
headers={"user-agent": "testclient"},
follow_redirects=True,
cookies=cookies,
transport=TestClientTransport( # type: ignore[arg-type]
transport=TestClientTransport(
client=self,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
Expand Down
74 changes: 53 additions & 21 deletions litestar/testing/transport.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this refactoring done, I think you actually don't even need 2 transports. The sync version could just do

with self.client.portal() as portal:
    return portal.call(self.handle_async_request, request)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the benefit of having 2 transports is that they respectively inherit their httpx counterparts so I think it's a little bit more mypy friendly but I may be wrong

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, that makes sense. Maybe could still move the request handling fn into the common base class?

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from urllib.parse import unquote

from anyio import Event
from httpx import ByteStream, Response
from httpx import AsyncBaseTransport, BaseTransport, ByteStream, Response

from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
from litestar.testing.websocket_test_session import WebSocketTestSession
Expand Down Expand Up @@ -43,7 +43,7 @@
context: Any | None


class TestClientTransport(Generic[T]):
class BaseTestClientTransport(Generic[T]):
def __init__(
self,
client: T,
Expand Down Expand Up @@ -138,18 +138,64 @@
"server": (host, port),
}

def handle_request(self, request: Request) -> Response:
def _prepare_request(self, request: Request) -> tuple[dict[str, Any], dict[str, Any]]:
scope = self.parse_request(request=request)
if scope["type"] == "websocket":
scope.update(
subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")]
)
session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type]
session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope))
raise ConnectionUpgradeExceptionError(session)

scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}})

raw_kwargs: dict[str, Any] = {"stream": BytesIO()}
return raw_kwargs, scope

def _prepare_response(self, request: Request, context: SendReceiveContext, raw_kwargs: dict[str, Any]) -> Response:
if not context["response_started"]: # pragma: no cover
if self.raise_server_exceptions:
assert context["response_started"], "TestClient did not receive any response." # noqa: S101
return Response(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)

stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read())
response = Response(**raw_kwargs, stream=stream, request=request)
setattr(response, "template", context["template"])
setattr(response, "context", context["context"])
return response


class AsyncTestClientTransport(AsyncBaseTransport, BaseTestClientTransport, Generic[T]):
async def handle_async_request(self, request: Request) -> Response:
raw_kwargs, scope = self._prepare_request(request)
try:
response_complete = Event()
context: SendReceiveContext = {
"response_complete": response_complete,
"request_complete": False,
"raw_kwargs": raw_kwargs,
"response_started": False,
"template": None,
"context": None,
}
await self.client.app(
scope,
self.create_receive(request=request, context=context),
self.create_send(request=request, context=context),
)
except BaseException as exc:

Check warning on line 186 in litestar/testing/transport.py

View check run for this annotation

Codecov / codecov/patch

litestar/testing/transport.py#L186

Added line #L186 was not covered by tests
if self.raise_server_exceptions:
raise exc
return Response(

Check warning on line 189 in litestar/testing/transport.py

View check run for this annotation

Codecov / codecov/patch

litestar/testing/transport.py#L188-L189

Added lines #L188 - L189 were not covered by tests
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)
else:
return self._prepare_response(request, context, raw_kwargs)


class TestClientTransport(BaseTransport, BaseTestClientTransport, Generic[T]):
def handle_request(self, request: Request) -> Response:
raw_kwargs, scope = self._prepare_request(request)

try:
with self.client.portal() as portal:
Expand All @@ -175,18 +221,4 @@
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)
else:
if not context["response_started"]: # pragma: no cover
if self.raise_server_exceptions:
assert context["response_started"], "TestClient did not receive any response." # noqa: S101
return Response(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)

stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read())
response = Response(**raw_kwargs, stream=stream, request=request)
setattr(response, "template", context["template"])
setattr(response, "context", context["context"])
return response

async def handle_async_request(self, request: Request) -> Response:
return self.handle_request(request=request)
return self._prepare_response(request, context, raw_kwargs)
3 changes: 2 additions & 1 deletion litestar/testing/websocket_test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from litestar.status_codes import WS_1000_NORMAL_CLOSURE

if TYPE_CHECKING:
from litestar.testing.client.async_client import AsyncTestClient
from litestar.testing.client.sync_client import TestClient
from litestar.types import (
WebSocketConnectEvent,
Expand All @@ -29,7 +30,7 @@ class WebSocketTestSession:

def __init__(
self,
client: TestClient[Any],
client: TestClient[Any] | AsyncTestClient[Any],
scope: WebSocketScope,
) -> None:
self.client = client
Expand Down
19 changes: 18 additions & 1 deletion tests/unit/test_testing/test_lifespan_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from asyncio import get_event_loop

import pytest

from litestar.testing import TestClient
from litestar import Litestar, get
from litestar.testing import AsyncTestClient, TestClient
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.types import Receive, Scope, Send

Expand Down Expand Up @@ -35,3 +38,17 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
handler = LifeSpanHandler(TestClient(app))
await handler.wait_shutdown()
handler.close()


async def test_multiple_clients_event_loop() -> None:
@get("/")
def return_loop_id() -> dict:
return {"loop_id": id(get_event_loop())}

app = Litestar(route_handlers=[return_loop_id])

async with AsyncTestClient(app) as client_1, AsyncTestClient(app) as client_2:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should even be the same loop with different client instances, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think so, this added test is just the mcve from the issue and it fails before the changes, is there something I should add ?

Copy link
Contributor Author

@euri10 euri10 Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize the diff is overly complicate, the PR is simple: before AsyncTestClient used TestClientTransport hence the issue.
so I created a BaseClientTransport that essentially takes care of request/response and added a ASGITestClientTransport that inherits httpx ASGIBaseTransport and overrides handle_async_request while TestClientTransport overrides handle_request, the BaseClientTransport is what's common between the 2 transports.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think so, this added test is just the mcve from the issue and it fails before the changes, is there something I should add ?

Yeah, maybe a test for the transport in particular that shows the client always uses the currently running loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually if this test was using TestClient instead of AsyncTestClient it wouldn't pass as start_blocking_portal creates a new event loop.

This should even be the same loop with different client instances, right?

by different client instances you meant a AsyncTestClient and a TestClient ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No I mean different instances AsyncClient. But actually, I think the test should probably just show that the client uses the running loop from the context it's created in.

response_1 = await client_1.get("/")
response_2 = await client_2.get("/")

assert response_1.json() == response_2.json() # FAILS
Loading