Skip to content

Commit 6a0981c

Browse files
committed
Fixed Starlette HTTP/WsConnection to resolve to Ellars Http/WsConnection
1 parent ce1241a commit 6a0981c

File tree

1 file changed

+56
-14
lines changed
  • ellar/core/middleware

1 file changed

+56
-14
lines changed

ellar/core/middleware/di.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,21 @@
22

33
from injector import CallableProvider
44
from starlette.middleware.errors import ServerErrorMiddleware
5+
from starlette.requests import (
6+
HTTPConnection as StarletteHTTPConnection,
7+
Request as StarletteRequest,
8+
)
9+
from starlette.websockets import WebSocket as StarletteWebSocket
510

611
from ellar.constants import SCOPE_EXECUTION_CONTEXT_PROVIDER, SCOPE_SERVICE_PROVIDER
7-
from ellar.core.connection import HTTPConnection, Request, WebSocket
12+
from ellar.core.connection.http import HTTPConnection, Request
13+
from ellar.core.connection.websocket import WebSocket
814
from ellar.core.context import ExecutionContext, IExecutionContext
915
from ellar.core.response import Response
1016
from ellar.types import ASGIApp, TReceive, TScope, TSend
1117

1218
if t.TYPE_CHECKING: # pragma: no cover
13-
from ellar.di.injector import EllarInjector
19+
from ellar.di.injector import EllarInjector, RequestServiceProvider
1420

1521

1622
class RequestServiceProviderMiddleware(ServerErrorMiddleware):
@@ -27,6 +33,47 @@ def __init__(
2733
)
2834
self.injector = injector
2935

36+
@classmethod
37+
def _register_connection(
38+
cls, ctx: ExecutionContext, request_provider: "RequestServiceProvider"
39+
) -> None:
40+
request_provider.update_context(
41+
HTTPConnection,
42+
CallableProvider(ctx.switch_to_http_connection),
43+
)
44+
request_provider.update_context(
45+
StarletteHTTPConnection,
46+
CallableProvider(ctx.switch_to_http_connection),
47+
)
48+
49+
@classmethod
50+
def _register_request(
51+
cls, ctx: ExecutionContext, request_provider: "RequestServiceProvider"
52+
) -> None:
53+
request_provider.update_context(
54+
Request, CallableProvider(ctx.switch_to_request)
55+
)
56+
request_provider.update_context(
57+
StarletteRequest, CallableProvider(ctx.switch_to_request)
58+
)
59+
60+
@classmethod
61+
def _register_response(
62+
cls, ctx: ExecutionContext, request_provider: "RequestServiceProvider"
63+
) -> None:
64+
request_provider.update_context(Response, CallableProvider(ctx.get_response))
65+
66+
@classmethod
67+
def _register_websocket(
68+
cls, ctx: ExecutionContext, request_provider: "RequestServiceProvider"
69+
) -> None:
70+
request_provider.update_context(
71+
WebSocket, CallableProvider(ctx.switch_to_websocket)
72+
)
73+
request_provider.update_context(
74+
StarletteWebSocket, CallableProvider(ctx.switch_to_websocket)
75+
)
76+
3077
async def __call__(self, scope: TScope, receive: TReceive, send: TSend) -> None:
3178
if scope["type"] not in ["http", "websocket"]: # pragma: no cover
3279
await super().__call__(scope, receive, send)
@@ -37,20 +84,15 @@ async def __call__(self, scope: TScope, receive: TReceive, send: TSend) -> None:
3784
request_provider.update_context(
3885
t.cast(t.Type, IExecutionContext), execute_context
3986
)
40-
41-
request_provider.update_context(
42-
HTTPConnection,
43-
CallableProvider(execute_context.switch_to_http_connection),
44-
)
4587
request_provider.update_context(
46-
WebSocket, CallableProvider(execute_context.switch_to_websocket)
47-
)
48-
request_provider.update_context(
49-
Request, CallableProvider(execute_context.switch_to_request)
50-
)
51-
request_provider.update_context(
52-
Response, CallableProvider(execute_context.get_response)
88+
t.cast(t.Type, ExecutionContext), execute_context
5389
)
90+
91+
self._register_request(execute_context, request_provider)
92+
self._register_websocket(execute_context, request_provider)
93+
self._register_response(execute_context, request_provider)
94+
self._register_connection(execute_context, request_provider)
95+
5496
scope[SCOPE_SERVICE_PROVIDER] = request_provider
5597
scope[SCOPE_EXECUTION_CONTEXT_PROVIDER] = execute_context
5698
await super().__call__(scope, receive, send)

0 commit comments

Comments
 (0)