Skip to content

Commit 741b1e8

Browse files
committed
Fixed wrong use of ASGI app args
1 parent bd5c417 commit 741b1e8

File tree

6 files changed

+53
-14
lines changed

6 files changed

+53
-14
lines changed

ellar/core/context/factory.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ellar.di import injectable
66
from ellar.di.exceptions import ServiceUnavailable
77
from ellar.services import Reflector
8-
from ellar.types import T
8+
from ellar.types import T, TReceive, TScope, TSend
99

1010
from .exceptions import HostContextException
1111
from .execution import ExecutionContext
@@ -97,13 +97,18 @@ class ExecutionContextFactory(IExecutionContextFactory):
9797
def __init__(self, reflector: Reflector) -> None:
9898
self.reflector = reflector
9999

100-
def create_context(self, operation: "RouteOperationBase") -> IExecutionContext:
100+
def create_context(
101+
self,
102+
operation: "RouteOperationBase",
103+
scope: TScope,
104+
receive: TReceive,
105+
send: TSend,
106+
) -> IExecutionContext:
101107
scoped_request_args = ASGI_CONTEXT_VAR.get()
102108

103109
if not scoped_request_args:
104110
raise ServiceUnavailable()
105111

106-
scope, receive, send = scoped_request_args.get_args()
107112
i_execution_context = ExecutionContext(
108113
scope=scope,
109114
receive=receive,

ellar/core/context/interface.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,11 @@ def create_context(self) -> IHostContext:
8181

8282
class IExecutionContextFactory(ABC):
8383
@abstractmethod
84-
def create_context(self, operation: "RouteOperationBase") -> IExecutionContext:
84+
def create_context(
85+
self,
86+
operation: "RouteOperationBase",
87+
scope: TScope,
88+
receive: TReceive,
89+
send: TSend,
90+
) -> IExecutionContext:
8591
pass

ellar/core/middleware/exceptions.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from starlette.exceptions import HTTPException
44

5-
from ellar.core.connection import HTTPConnection
6-
from ellar.core.context import HostContext, IHostContext
5+
from ellar.constants import SCOPE_SERVICE_PROVIDER
6+
from ellar.core.context import IHostContext
77
from ellar.types import ASGIApp, TMessage, TReceive, TScope, TSend
88

99
if t.TYPE_CHECKING: # pragma: no cover
1010
from ellar.core.exceptions.service import ExceptionMiddlewareService
11+
from ellar.di import EllarInjector
1112

1213

1314
class ExceptionMiddleware:
@@ -34,6 +35,7 @@ async def sender(message: TMessage) -> None:
3435
if message["type"] == "http.response.start":
3536
response_started = True
3637
await send(message)
38+
return
3739

3840
try:
3941
await self.app(scope, receive, sender)
@@ -57,15 +59,17 @@ async def sender(message: TMessage) -> None:
5759
msg = "Caught handled exception, but response already started."
5860
raise RuntimeError(msg) from exc
5961

60-
connection = HTTPConnection(scope, receive)
62+
service_provider: "EllarInjector" = t.cast(
63+
"EllarInjector", scope[SCOPE_SERVICE_PROVIDER]
64+
)
6165

62-
if not connection.service_provider: # pragma: no cover
63-
context = HostContext(scope=scope, receive=receive, send=send)
64-
else:
65-
context = connection.service_provider.get(IHostContext)
66+
context = service_provider.get(IHostContext)
6667

6768
if context.get_type() == "http":
6869
response = await handler.catch(context, exc)
70+
if not response and not response_started:
71+
msg = "HTTP ExceptionHandler must return a response."
72+
raise RuntimeError(msg) from exc
6973
await response(scope, receive, sender)
7074
elif context.get_type() == "websocket":
7175
await handler.catch(context, exc)

ellar/core/routing/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ async def app(self, scope: TScope, receive: TReceive, send: TSend) -> None:
6262
service_provider = t.cast(EllarInjector, scope[SCOPE_SERVICE_PROVIDER])
6363

6464
execution_context_factory = service_provider.get(IExecutionContextFactory)
65-
context = execution_context_factory.create_context(operation=self)
65+
context = execution_context_factory.create_context(
66+
operation=self, scope=scope, receive=receive, send=send
67+
)
6668

6769
await self.run_route_guards(context=context)
6870
await self._handle_request(context=context)

tests/test_application/test_replacing_app_services.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,12 @@ class NewExecutionHostFactory(IExecutionContextFactory):
4242
def __init__(self, reflector: Reflector):
4343
self.reflector = reflector
4444

45-
def create_context(self, operation) -> IExecutionContext:
45+
def create_context(self, operation, scope, receive, send) -> IExecutionContext:
4646
scoped_request_args = ASGI_CONTEXT_VAR.get()
4747

4848
if not scoped_request_args:
4949
raise ServiceUnavailable()
5050

51-
scope, receive, send = scoped_request_args.get_args()
5251
i_execution_context = NewExecutionContext(
5352
scope=scope,
5453
receive=receive,

tests/test_exceptions/test_custom_exceptions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ async def catch(
4949
return JSONResponse({"detail": "HttpException Override"}, status_code=400)
5050

5151

52+
class RuntimeHTTPException(IExceptionHandler):
53+
exception_type_or_code = HTTPException
54+
55+
async def catch(
56+
self, ctx: IHostContext, exc: t.Union[t.Any, Exception]
57+
) -> t.Union[Response, t.Any]:
58+
return None
59+
60+
5261
class ServerErrorHandler(IExceptionHandler):
5362
exception_type_or_code = 500
5463

@@ -234,6 +243,20 @@ def homepage():
234243
assert res.json() == {"detail": "HttpException Override"}
235244

236245

246+
def test_application_http_exception_handler_raise_exception_for_returning_none():
247+
@get()
248+
def homepage():
249+
raise HTTPException(detail="Bad Request", status_code=400)
250+
251+
tm = TestClientFactory.create_test_module()
252+
tm.app.router.append(homepage)
253+
tm.app.add_exception_handler(RuntimeHTTPException())
254+
with pytest.raises(
255+
RuntimeError, match="HTTP ExceptionHandler must return a response."
256+
):
257+
tm.get_client().get("/")
258+
259+
237260
def test_application_adding_same_exception_twice():
238261
tm = TestClientFactory.create_test_module()
239262
with patch.object(

0 commit comments

Comments
 (0)