2
2
import typing as t
3
3
4
4
from socketio import AsyncServer
5
+ from starlette import status
5
6
from starlette .concurrency import run_in_threadpool
7
+ from starlette .exceptions import WebSocketException
6
8
from starlette .routing import compile_path
7
9
8
10
from ellar .constants import (
12
14
NOT_SET ,
13
15
SCOPE_SERVICE_PROVIDER ,
14
16
)
15
- from ellar .core import IExecutionContext
17
+ from ellar .core import Config , IExecutionContext
16
18
from ellar .core .context import IExecutionContextFactory
17
19
from ellar .core .exceptions import WebSocketRequestValidationError
18
20
from ellar .core .params import WebsocketEndpointArgsModel
21
+ from ellar .core .serializer import serialize_object
19
22
from ellar .core .services import Reflector
20
23
from ellar .di import EllarInjector
21
24
from ellar .helper import get_name
22
25
from ellar .reflect import reflect
23
- from ellar .socket_io .model import GatewayBase , GatewayContext
26
+ from ellar .socket_io .context import GatewayContext
27
+ from ellar .socket_io .model import GatewayBase
24
28
from ellar .socket_io .responses import WsResponse
25
29
26
- if t .TYPE_CHECKING :
30
+ if t .TYPE_CHECKING : # pragma: no cover
27
31
from ellar .core import GuardCanActivate
28
32
from ellar .core .params import ExtraEndpointArg
29
33
@@ -61,7 +65,7 @@ def _load_model(self) -> None:
61
65
extra_route_args : t .Union [t .List ["ExtraEndpointArg" ], "ExtraEndpointArg" ] = (
62
66
reflect .get_metadata (EXTRA_ROUTE_ARGS_KEY , self .endpoint ) or []
63
67
)
64
- if not isinstance (extra_route_args , list ):
68
+ if not isinstance (extra_route_args , list ): # pragma: no cover
65
69
extra_route_args = [extra_route_args ]
66
70
67
71
if self .endpoint_parameter_model is NOT_SET :
@@ -73,6 +77,39 @@ def _load_model(self) -> None:
73
77
)
74
78
self .endpoint_parameter_model .build_model ()
75
79
80
+ async def _run_with_exception_handling (
81
+ self , gateway_instance : GatewayBase , sid : str
82
+ ) -> None :
83
+ config = gateway_instance .context .get_service_provider ().get (Config )
84
+ try :
85
+ await self .run_route_guards (context = gateway_instance .context )
86
+ await self ._run_handler (
87
+ context = gateway_instance .context , gateway_instance = gateway_instance
88
+ )
89
+
90
+ except WebSocketException as aex :
91
+ await self ._handle_error (
92
+ sid = sid ,
93
+ code = aex .code ,
94
+ reason = serialize_object (aex .reason ),
95
+ )
96
+ except WebSocketRequestValidationError as wex :
97
+ await self ._handle_error (
98
+ sid = sid ,
99
+ code = status .WS_1007_INVALID_FRAME_PAYLOAD_DATA ,
100
+ reason = serialize_object (wex .errors ()),
101
+ )
102
+ except Exception as ex :
103
+ await self ._handle_error (
104
+ sid = sid ,
105
+ code = status .WS_1011_INTERNAL_ERROR ,
106
+ reason = str (ex ) if config .DEBUG else "Something went wrong" ,
107
+ )
108
+
109
+ async def _handle_error (self , sid : str , code : int , reason : t .Any ) -> None :
110
+ await self ._server .emit ("error" , {"code" : code , "reason" : reason }, room = sid )
111
+ await self ._server .disconnect (sid = sid )
112
+
76
113
async def _context_handler (self , sid : str , environment : t .Dict ) -> t .Any :
77
114
service_provider = t .cast (
78
115
EllarInjector , environment ["asgi.scope" ][SCOPE_SERVICE_PROVIDER ]
@@ -86,9 +123,7 @@ async def _context_handler(self, sid: str, environment: t.Dict) -> t.Any:
86
123
send = environment ["asgi.send" ],
87
124
)
88
125
gateway_instance = self ._get_gateway_instance_and_context (ctx = context , sid = sid )
89
-
90
- await self .run_route_guards (context = context )
91
- await self ._run_handler (context = context , gateway_instance = gateway_instance )
126
+ await self ._run_with_exception_handling (gateway_instance , sid = sid )
92
127
93
128
def _register_handler (self ) -> None :
94
129
@self ._server .on (self ._event )
@@ -129,7 +164,7 @@ async def _run_handler(
129
164
await self ._server .emit (** res .dict ())
130
165
131
166
@t .no_type_check
132
- async def run_route_guards (self , context : IExecutionContext ) -> None :
167
+ async def run_route_guards (self , context : GatewayContext ) -> None :
133
168
reflector = context .get_service_provider ().get (Reflector )
134
169
app = context .get_app ()
135
170
@@ -149,7 +184,15 @@ async def run_route_guards(self, context: IExecutionContext) -> None:
149
184
150
185
result = await guard .can_activate (context )
151
186
if not result :
152
- guard .raise_exception ()
187
+ await self ._server .emit (
188
+ "error" ,
189
+ {
190
+ "code" : status .WS_1011_INTERNAL_ERROR ,
191
+ "reason" : "Authorization Failed" ,
192
+ },
193
+ room = context .sid ,
194
+ )
195
+ await self ._server .disconnect (sid = context .sid )
153
196
154
197
def get_control_type (self ) -> t .Type [GatewayBase ]:
155
198
"""
@@ -159,19 +202,17 @@ def get_control_type(self) -> t.Type[GatewayBase]:
159
202
"""
160
203
if not hasattr (self , "_control_type" ):
161
204
_control_type = reflect .get_metadata (CONTROLLER_CLASS_KEY , self .endpoint )
162
- if _control_type is None :
205
+ if _control_type is None : # pragma: no cover
163
206
raise Exception ("Operation must have a single control type." )
164
207
self ._control_type = t .cast (t .Type [GatewayBase ], _control_type )
165
208
166
209
return self ._control_type
167
210
168
211
def _get_gateway_instance (self , ctx : IExecutionContext ) -> GatewayBase :
169
- gateway_type : t .Optional [t .Type ] = reflect .get_metadata (
170
- CONTROLLER_CLASS_KEY , self .endpoint
171
- )
212
+ gateway_type = self .get_control_type ()
172
213
if not gateway_type or (
173
214
gateway_type and not issubclass (gateway_type , GatewayBase )
174
- ):
215
+ ): # pragma: no cover
175
216
raise RuntimeError ("GatewayBase Type was not found" )
176
217
177
218
service_provider = ctx .get_service_provider ()
@@ -197,8 +238,6 @@ def _get_gateway_instance_and_context(
197
238
class SocketMessageOperation (SocketOperationConnection ):
198
239
async def _context_handler (self , sid : str , message : t .Any ) -> t .Any :
199
240
sid_environ = self ._server .get_environ (sid )
200
- if not sid_environ :
201
- raise Exception ("Socket Environment not found." )
202
241
203
242
service_provider = t .cast (
204
243
EllarInjector , sid_environ ["asgi.scope" ][SCOPE_SERVICE_PROVIDER ]
@@ -216,8 +255,7 @@ async def _context_handler(self, sid: str, message: t.Any) -> t.Any:
216
255
ctx = context , sid = sid , message = message
217
256
)
218
257
219
- await self .run_route_guards (context = context )
220
- await self ._run_handler (context = context , gateway_instance = gateway_instance )
258
+ await self ._run_with_exception_handling (gateway_instance , sid )
221
259
222
260
def _register_handler (self ) -> None :
223
261
@self ._server .on (self ._event )
0 commit comments