2
2
3
3
from injector import CallableProvider
4
4
from starlette .middleware .errors import ServerErrorMiddleware
5
+ from starlette .requests import (
6
+ HTTPConnection as StarletteHTTPConnection ,
7
+ Request as StarletteRequest ,
8
+ )
9
+ from starlette .websockets import WebSocket as StarletteWebSocket
5
10
6
11
from ellar .constants import SCOPE_EXECUTION_CONTEXT_PROVIDER , SCOPE_SERVICE_PROVIDER
7
- from ellar .core .connection import HTTPConnection , Request , WebSocket
12
+ from ellar .core .connection .http import HTTPConnection , Request
13
+ from ellar .core .connection .websocket import WebSocket
8
14
from ellar .core .context import ExecutionContext , IExecutionContext
9
15
from ellar .core .response import Response
10
16
from ellar .types import ASGIApp , TReceive , TScope , TSend
11
17
12
18
if t .TYPE_CHECKING : # pragma: no cover
13
- from ellar .di .injector import EllarInjector
19
+ from ellar .di .injector import EllarInjector , RequestServiceProvider
14
20
15
21
16
22
class RequestServiceProviderMiddleware (ServerErrorMiddleware ):
@@ -27,6 +33,47 @@ def __init__(
27
33
)
28
34
self .injector = injector
29
35
36
+ @classmethod
37
+ def _register_connection (
38
+ cls , ctx : ExecutionContext , request_provider : "RequestServiceProvider"
39
+ ) -> None :
40
+ request_provider .update_context (
41
+ HTTPConnection ,
42
+ CallableProvider (ctx .switch_to_http_connection ),
43
+ )
44
+ request_provider .update_context (
45
+ StarletteHTTPConnection ,
46
+ CallableProvider (ctx .switch_to_http_connection ),
47
+ )
48
+
49
+ @classmethod
50
+ def _register_request (
51
+ cls , ctx : ExecutionContext , request_provider : "RequestServiceProvider"
52
+ ) -> None :
53
+ request_provider .update_context (
54
+ Request , CallableProvider (ctx .switch_to_request )
55
+ )
56
+ request_provider .update_context (
57
+ StarletteRequest , CallableProvider (ctx .switch_to_request )
58
+ )
59
+
60
+ @classmethod
61
+ def _register_response (
62
+ cls , ctx : ExecutionContext , request_provider : "RequestServiceProvider"
63
+ ) -> None :
64
+ request_provider .update_context (Response , CallableProvider (ctx .get_response ))
65
+
66
+ @classmethod
67
+ def _register_websocket (
68
+ cls , ctx : ExecutionContext , request_provider : "RequestServiceProvider"
69
+ ) -> None :
70
+ request_provider .update_context (
71
+ WebSocket , CallableProvider (ctx .switch_to_websocket )
72
+ )
73
+ request_provider .update_context (
74
+ StarletteWebSocket , CallableProvider (ctx .switch_to_websocket )
75
+ )
76
+
30
77
async def __call__ (self , scope : TScope , receive : TReceive , send : TSend ) -> None :
31
78
if scope ["type" ] not in ["http" , "websocket" ]: # pragma: no cover
32
79
await super ().__call__ (scope , receive , send )
@@ -37,20 +84,15 @@ async def __call__(self, scope: TScope, receive: TReceive, send: TSend) -> None:
37
84
request_provider .update_context (
38
85
t .cast (t .Type , IExecutionContext ), execute_context
39
86
)
40
-
41
- request_provider .update_context (
42
- HTTPConnection ,
43
- CallableProvider (execute_context .switch_to_http_connection ),
44
- )
45
87
request_provider .update_context (
46
- WebSocket , CallableProvider (execute_context .switch_to_websocket )
47
- )
48
- request_provider .update_context (
49
- Request , CallableProvider (execute_context .switch_to_request )
50
- )
51
- request_provider .update_context (
52
- Response , CallableProvider (execute_context .get_response )
88
+ t .cast (t .Type , ExecutionContext ), execute_context
53
89
)
90
+
91
+ self ._register_request (execute_context , request_provider )
92
+ self ._register_websocket (execute_context , request_provider )
93
+ self ._register_response (execute_context , request_provider )
94
+ self ._register_connection (execute_context , request_provider )
95
+
54
96
scope [SCOPE_SERVICE_PROVIDER ] = request_provider
55
97
scope [SCOPE_EXECUTION_CONTEXT_PROVIDER ] = execute_context
56
98
await super ().__call__ (scope , receive , send )
0 commit comments