Skip to content

Commit 5cc4ddf

Browse files
Kludexgraingert
andauthored
Raise exception from background task on BaseHTTPMiddleware (#2812)
Co-authored-by: Thomas Grainger <tagrain@gmail.com>
1 parent f13d354 commit 5cc4ddf

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

starlette/middleware/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
103103
request = _CachedRequest(scope, receive)
104104
wrapped_receive = request.wrapped_receive
105105
response_sent = anyio.Event()
106+
app_exc: Exception | None = None
106107

107108
async def call_next(request: Request) -> Response:
108-
app_exc: Exception | None = None
109-
110109
async def receive_or_disconnect() -> Message:
111110
if response_sent.is_set():
112111
return {"type": "http.disconnect"}
@@ -165,9 +164,6 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
165164
if not message.get("more_body", False):
166165
break
167166

168-
if app_exc is not None:
169-
raise app_exc
170-
171167
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
172168
response.raw_headers = message["headers"]
173169
return response
@@ -181,6 +177,9 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
181177
response_sent.set()
182178
recv_stream.close()
183179

180+
if app_exc is not None:
181+
raise app_exc
182+
184183
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
185184
raise NotImplementedError() # pragma: no cover
186185

starlette/responses.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import anyio
1919
import anyio.to_thread
2020

21+
from starlette._utils import collapse_excgroups
2122
from starlette.background import BackgroundTask
2223
from starlette.concurrency import iterate_in_threadpool
2324
from starlette.datastructures import URL, Headers, MutableHeaders
@@ -258,14 +259,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
258259
except OSError:
259260
raise ClientDisconnect()
260261
else:
261-
async with anyio.create_task_group() as task_group:
262+
with collapse_excgroups():
263+
async with anyio.create_task_group() as task_group:
262264

263-
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
264-
await func()
265-
task_group.cancel_scope.cancel()
265+
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
266+
await func()
267+
task_group.cancel_scope.cancel()
266268

267-
task_group.start_soon(wrap, partial(self.stream_response, send))
268-
await wrap(partial(self.listen_for_disconnect, receive))
269+
task_group.start_soon(wrap, partial(self.stream_response, send))
270+
await wrap(partial(self.listen_for_disconnect, receive))
269271

270272
if self.background is not None:
271273
await self.background()

tests/middleware/test_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,29 @@ async def send(message: Message) -> None:
297297
assert background_task_run.is_set()
298298

299299

300+
def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None:
301+
# test for https://github.com/encode/starlette/issues/2625
302+
303+
async def sleep_and_set() -> None:
304+
await anyio.sleep(0.1)
305+
raise ValueError("TEST")
306+
307+
async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
308+
return PlainTextResponse(background=BackgroundTask(sleep_and_set))
309+
310+
async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
311+
return await call_next(request)
312+
313+
app = Starlette(
314+
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
315+
routes=[Route("/", endpoint_with_background_task)],
316+
)
317+
318+
client = test_client_factory(app)
319+
with pytest.raises(ValueError, match="TEST"):
320+
client.get("/")
321+
322+
300323
@pytest.mark.anyio
301324
async def test_do_not_block_on_background_tasks() -> None:
302325
response_complete = anyio.Event()

0 commit comments

Comments
 (0)