diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 2a59337e5..b49ab611f 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -104,6 +104,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: wrapped_receive = request.wrapped_receive response_sent = anyio.Event() app_exc: Exception | None = None + exception_already_raised = False async def call_next(request: Request) -> Response: async def receive_or_disconnect() -> Message: @@ -150,6 +151,8 @@ async def coro() -> None: message = await recv_stream.receive() except anyio.EndOfStream: if app_exc is not None: + nonlocal exception_already_raised + exception_already_raised = True raise app_exc raise RuntimeError("No response returned.") @@ -176,8 +179,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: await response(scope, wrapped_receive, send) response_sent.set() recv_stream.close() - - if app_exc is not None: + if app_exc is not None and not exception_already_raised: raise app_exc async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index e4e82077f..427ec44ac 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -320,6 +320,27 @@ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> R client.get("/") +def test_exception_can_be_caught(test_client_factory: TestClientFactory) -> None: + async def error_endpoint(_: Request) -> None: + raise ValueError("TEST") + + async def catches_error(request: Request, call_next: RequestResponseEndpoint) -> Response: + try: + return await call_next(request) + except ValueError as exc: + return PlainTextResponse(content=str(exc), status_code=400) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=catches_error)], + routes=[Route("/", error_endpoint)], + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.status_code == 400 + assert response.text == "TEST" + + @pytest.mark.anyio async def test_do_not_block_on_background_tasks() -> None: response_complete = anyio.Event()