Skip to content

fix: add a AsyncTestClientTransport for the AsyncTestClient #4142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
30 changes: 6 additions & 24 deletions litestar/testing/client/async_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar

from httpx import USE_CLIENT_DEFAULT, AsyncClient

from litestar.testing.client.base import BaseTestClient
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport
from litestar.testing.life_span_handler import AsyncLifeSpanHandler
from litestar.testing.transport import AsyncTestClientTransport, ConnectionUpgradeExceptionError
from litestar.types import AnyIOBackend, ASGIApp

if TYPE_CHECKING:
Expand All @@ -24,14 +23,10 @@
from litestar.middleware.session.base import BaseBackendConfig
from litestar.testing.websocket_test_session import WebSocketTestSession


T = TypeVar("T", bound=ASGIApp)


class AsyncTestClient(AsyncClient, BaseTestClient, Generic[T]): # type: ignore[misc]
lifespan_handler: LifeSpanHandler[Any]
exit_stack: AsyncExitStack

def __init__(
self,
app: T,
Expand Down Expand Up @@ -74,7 +69,7 @@ def __init__(
headers={"user-agent": "testclient"},
follow_redirects=True,
cookies=cookies,
transport=TestClientTransport( # type: ignore [arg-type]
transport=AsyncTestClientTransport(
client=self,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
Expand All @@ -83,24 +78,11 @@ def __init__(
)

async def __aenter__(self) -> Self:
async with AsyncExitStack() as stack:
self.blocking_portal = portal = stack.enter_context(self.portal())
self.lifespan_handler = LifeSpanHandler(client=self)
stack.enter_context(self.lifespan_handler)

@stack.callback
def reset_portal() -> None:
delattr(self, "blocking_portal")

@stack.callback
def wait_shutdown() -> None:
portal.call(self.lifespan_handler.wait_shutdown)

self.exit_stack = stack.pop_all()
return self
async with AsyncLifeSpanHandler(client=self) as _:
return self

async def __aexit__(self, *args: Any) -> None:
await self.exit_stack.aclose()
pass

async def websocket_connect(
self,
Expand Down
2 changes: 1 addition & 1 deletion litestar/testing/client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
headers={"user-agent": "testclient"},
follow_redirects=True,
cookies=cookies,
transport=TestClientTransport( # type: ignore[arg-type]
transport=TestClientTransport(
client=self,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
Expand Down
113 changes: 112 additions & 1 deletion litestar/testing/life_span_handler.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

import contextlib
import warnings
from math import inf
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast

from anyio import create_memory_object_stream
import anyio
from anyio import TASK_STATUS_IGNORED, create_memory_object_stream
from anyio.streams.stapled import StapledObjectStream

from litestar.testing.client.base import BaseTestClient

if TYPE_CHECKING:
from types import TracebackType

from anyio.abc import TaskStatus

from litestar.types import (
LifeSpanReceiveMessage, # noqa: F401
LifeSpanSendMessage,
Expand Down Expand Up @@ -128,3 +132,110 @@ async def lifespan(self) -> None:
await self.client.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
await self.stream_send.send(None)


class AsyncLifeSpanHandler(Generic[T]):
__slots__ = (
"_startup_done",
"client",
"stream_receive",
"stream_send",
"task",
)

def __init__(self, client: T) -> None:
self.client = client
self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf)) # type: ignore[arg-type]
self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf)) # type: ignore[arg-type]
self._startup_done = False

async def _ensure_setup(self, is_safe: bool = False) -> None:
if self._startup_done:
return
if not is_safe:
warnings.warn(
"AsyncLifeSpanHandler used with implicit startup; Use AsyncLifeSpanHandler as a async context manager instead. "
"Implicit startup will be deprecated in version 3.0.",
category=DeprecationWarning,
stacklevel=2,
)

self._startup_done = True
async with anyio.create_task_group() as task_group:
await task_group.start(self.wait_startup)
self.task = await task_group.start(self.lifespan)
await task_group.start(self.wait_shutdown)

async def aclose(self) -> None:
await self.stream_send.aclose()
await self.stream_receive.aclose()

async def __aenter__(self) -> AsyncLifeSpanHandler:
try:
await self._ensure_setup(is_safe=True)
except Exception as exc:
await self.aclose()
raise exc
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
await self.aclose()

async def receive(self) -> LifeSpanSendMessage:
await self._ensure_setup()
message = await self.stream_send.receive()
if message is None and self.task:
self.task.result()
return cast("LifeSpanSendMessage", message)

async def wait_startup(self, *, task_status: TaskStatus[None]) -> None:
task_status.started()
await self._ensure_setup()
event: LifeSpanStartupEvent = {"type": "lifespan.startup"}
await self.stream_receive.send(event)

message = await self.receive()
if message["type"] not in (
"lifespan.startup.complete",
"lifespan.startup.failed",
):
raise RuntimeError(
"Received unexpected ASGI message type. Expected 'lifespan.startup.complete' or "
f"'lifespan.startup.failed'. Got {message['type']!r}",
)
if message["type"] == "lifespan.startup.failed":
await self.receive()

async def wait_shutdown(self, *, task_status: TaskStatus[None]) -> None:
task_status.started()
await self._ensure_setup()
async with self.stream_send:
lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"}
await self.stream_receive.send(lifespan_shutdown_event)

message = await self.receive()
if message["type"] not in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
):
raise RuntimeError(
"Received unexpected ASGI message type. Expected 'lifespan.shutdown.complete' or "
f"'lifespan.shutdown.failed'. Got {message['type']!r}",
)
if message["type"] == "lifespan.shutdown.failed":
await self.receive()

async def lifespan(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
task_status.started()
await self._ensure_setup()
scope = {"type": "lifespan"}
try:
await self.client.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
with contextlib.suppress(anyio.ClosedResourceError):
await self.stream_send.send(None)
74 changes: 53 additions & 21 deletions litestar/testing/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from urllib.parse import unquote

from anyio import Event
from httpx import ByteStream, Response
from httpx import AsyncBaseTransport, BaseTransport, ByteStream, Response

from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
from litestar.testing.websocket_test_session import WebSocketTestSession
Expand Down Expand Up @@ -43,7 +43,7 @@ class SendReceiveContext(TypedDict):
context: Any | None


class TestClientTransport(Generic[T]):
class BaseTestClientTransport(Generic[T]):
def __init__(
self,
client: T,
Expand Down Expand Up @@ -138,18 +138,64 @@ def parse_request(self, request: Request) -> dict[str, Any]:
"server": (host, port),
}

def handle_request(self, request: Request) -> Response:
def _prepare_request(self, request: Request) -> tuple[dict[str, Any], dict[str, Any]]:
scope = self.parse_request(request=request)
if scope["type"] == "websocket":
scope.update(
subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")]
)
session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type]
session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope))
raise ConnectionUpgradeExceptionError(session)

scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}})

raw_kwargs: dict[str, Any] = {"stream": BytesIO()}
return raw_kwargs, scope

def _prepare_response(self, request: Request, context: SendReceiveContext, raw_kwargs: dict[str, Any]) -> Response:
if not context["response_started"]: # pragma: no cover
if self.raise_server_exceptions:
assert context["response_started"], "TestClient did not receive any response." # noqa: S101
return Response(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)

stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read())
response = Response(**raw_kwargs, stream=stream, request=request)
setattr(response, "template", context["template"])
setattr(response, "context", context["context"])
return response


class AsyncTestClientTransport(AsyncBaseTransport, BaseTestClientTransport, Generic[T]):
async def handle_async_request(self, request: Request) -> Response:
raw_kwargs, scope = self._prepare_request(request)
try:
response_complete = Event()
context: SendReceiveContext = {
"response_complete": response_complete,
"request_complete": False,
"raw_kwargs": raw_kwargs,
"response_started": False,
"template": None,
"context": None,
}
await self.client.app(
scope,
self.create_receive(request=request, context=context),
self.create_send(request=request, context=context),
)
except BaseException as exc:
if self.raise_server_exceptions:
raise exc
return Response(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)
else:
return self._prepare_response(request, context, raw_kwargs)


class TestClientTransport(BaseTransport, BaseTestClientTransport, Generic[T]):
def handle_request(self, request: Request) -> Response:
raw_kwargs, scope = self._prepare_request(request)

try:
with self.client.portal() as portal:
Expand All @@ -175,18 +221,4 @@ def handle_request(self, request: Request) -> Response:
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)
else:
if not context["response_started"]: # pragma: no cover
if self.raise_server_exceptions:
assert context["response_started"], "TestClient did not receive any response." # noqa: S101
return Response(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
)

stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read())
response = Response(**raw_kwargs, stream=stream, request=request)
setattr(response, "template", context["template"])
setattr(response, "context", context["context"])
return response

async def handle_async_request(self, request: Request) -> Response:
return self.handle_request(request=request)
return self._prepare_response(request, context, raw_kwargs)
3 changes: 2 additions & 1 deletion litestar/testing/websocket_test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from litestar.status_codes import WS_1000_NORMAL_CLOSURE

if TYPE_CHECKING:
from litestar.testing.client.async_client import AsyncTestClient
from litestar.testing.client.sync_client import TestClient
from litestar.types import (
WebSocketConnectEvent,
Expand All @@ -29,7 +30,7 @@ class WebSocketTestSession:

def __init__(
self,
client: TestClient[Any],
client: TestClient[Any] | AsyncTestClient[Any],
scope: WebSocketScope,
) -> None:
self.client = client
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/test_testing/test_lifespan_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import asyncio
import contextlib
from asyncio import get_event_loop
from typing import AsyncGenerator
from unittest.mock import MagicMock

import pytest

from litestar.testing import TestClient
from litestar import Litestar, get
from litestar.testing import AsyncTestClient, TestClient
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.types import Receive, Scope, Send

Expand Down Expand Up @@ -35,3 +42,33 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
handler = LifeSpanHandler(TestClient(app))
await handler.wait_shutdown()
handler.close()


async def test_multiple_clients_event_loop() -> None:
@get("/")
def return_loop_id() -> dict:
return {"loop_id": id(get_event_loop())}

app = Litestar(route_handlers=[return_loop_id])

async with AsyncTestClient(app) as client_1, AsyncTestClient(app) as client_2:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should even be the same loop with different client instances, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think so, this added test is just the mcve from the issue and it fails before the changes, is there something I should add ?

Copy link
Contributor Author

@euri10 euri10 Apr 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize the diff is overly complicate, the PR is simple: before AsyncTestClient used TestClientTransport hence the issue.
so I created a BaseClientTransport that essentially takes care of request/response and added a ASGITestClientTransport that inherits httpx ASGIBaseTransport and overrides handle_async_request while TestClientTransport overrides handle_request, the BaseClientTransport is what's common between the 2 transports.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think so, this added test is just the mcve from the issue and it fails before the changes, is there something I should add ?

Yeah, maybe a test for the transport in particular that shows the client always uses the currently running loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually if this test was using TestClient instead of AsyncTestClient it wouldn't pass as start_blocking_portal creates a new event loop.

This should even be the same loop with different client instances, right?

by different client instances you meant a AsyncTestClient and a TestClient ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No I mean different instances AsyncClient. But actually, I think the test should probably just show that the client uses the running loop from the context it's created in.

response_1 = await client_1.get("/")
response_2 = await client_2.get("/")

assert response_1.json() == response_2.json()


async def test_lifespan_loop() -> None:
mock = MagicMock()

@contextlib.asynccontextmanager
async def lifespan(app: Litestar) -> AsyncGenerator[None, None]:
mock(asyncio.get_running_loop())
yield

app = Litestar(lifespan=[lifespan])

async with AsyncTestClient(app):
pass

mock.assert_called_once_with(asyncio.get_running_loop())
2 changes: 1 addition & 1 deletion tests/unit/test_testing/test_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async def test_error_handling_on_startup(
async def test_error_handling_on_shutdown(
test_client_backend: "AnyIOBackend", test_client_cls: Type[AnyTestClient]
) -> None:
with pytest.raises(RuntimeError):
with pytest.raises(_ExceptionGroup):
async with maybe_async_cm(
test_client_cls(Litestar(on_shutdown=[raise_error]), backend=test_client_backend) # pyright: ignore
):
Expand Down
Loading