Skip to content

Commit b010303

Browse files
committed
Added exception handling for message process handling
1 parent 9ab3b40 commit b010303

File tree

1 file changed

+56
-18
lines changed

1 file changed

+56
-18
lines changed

ellar/socket_io/gateway.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import typing as t
33

44
from socketio import AsyncServer
5+
from starlette import status
56
from starlette.concurrency import run_in_threadpool
7+
from starlette.exceptions import WebSocketException
68
from starlette.routing import compile_path
79

810
from ellar.constants import (
@@ -12,18 +14,20 @@
1214
NOT_SET,
1315
SCOPE_SERVICE_PROVIDER,
1416
)
15-
from ellar.core import IExecutionContext
17+
from ellar.core import Config, IExecutionContext
1618
from ellar.core.context import IExecutionContextFactory
1719
from ellar.core.exceptions import WebSocketRequestValidationError
1820
from ellar.core.params import WebsocketEndpointArgsModel
21+
from ellar.core.serializer import serialize_object
1922
from ellar.core.services import Reflector
2023
from ellar.di import EllarInjector
2124
from ellar.helper import get_name
2225
from ellar.reflect import reflect
23-
from ellar.socket_io.model import GatewayBase, GatewayContext
26+
from ellar.socket_io.context import GatewayContext
27+
from ellar.socket_io.model import GatewayBase
2428
from ellar.socket_io.responses import WsResponse
2529

26-
if t.TYPE_CHECKING:
30+
if t.TYPE_CHECKING: # pragma: no cover
2731
from ellar.core import GuardCanActivate
2832
from ellar.core.params import ExtraEndpointArg
2933

@@ -61,7 +65,7 @@ def _load_model(self) -> None:
6165
extra_route_args: t.Union[t.List["ExtraEndpointArg"], "ExtraEndpointArg"] = (
6266
reflect.get_metadata(EXTRA_ROUTE_ARGS_KEY, self.endpoint) or []
6367
)
64-
if not isinstance(extra_route_args, list):
68+
if not isinstance(extra_route_args, list): # pragma: no cover
6569
extra_route_args = [extra_route_args]
6670

6771
if self.endpoint_parameter_model is NOT_SET:
@@ -73,6 +77,39 @@ def _load_model(self) -> None:
7377
)
7478
self.endpoint_parameter_model.build_model()
7579

80+
async def _run_with_exception_handling(
81+
self, gateway_instance: GatewayBase, sid: str
82+
) -> None:
83+
config = gateway_instance.context.get_service_provider().get(Config)
84+
try:
85+
await self.run_route_guards(context=gateway_instance.context)
86+
await self._run_handler(
87+
context=gateway_instance.context, gateway_instance=gateway_instance
88+
)
89+
90+
except WebSocketException as aex:
91+
await self._handle_error(
92+
sid=sid,
93+
code=aex.code,
94+
reason=serialize_object(aex.reason),
95+
)
96+
except WebSocketRequestValidationError as wex:
97+
await self._handle_error(
98+
sid=sid,
99+
code=status.WS_1007_INVALID_FRAME_PAYLOAD_DATA,
100+
reason=serialize_object(wex.errors()),
101+
)
102+
except Exception as ex:
103+
await self._handle_error(
104+
sid=sid,
105+
code=status.WS_1011_INTERNAL_ERROR,
106+
reason=str(ex) if config.DEBUG else "Something went wrong",
107+
)
108+
109+
async def _handle_error(self, sid: str, code: int, reason: t.Any) -> None:
110+
await self._server.emit("error", {"code": code, "reason": reason}, room=sid)
111+
await self._server.disconnect(sid=sid)
112+
76113
async def _context_handler(self, sid: str, environment: t.Dict) -> t.Any:
77114
service_provider = t.cast(
78115
EllarInjector, environment["asgi.scope"][SCOPE_SERVICE_PROVIDER]
@@ -86,9 +123,7 @@ async def _context_handler(self, sid: str, environment: t.Dict) -> t.Any:
86123
send=environment["asgi.send"],
87124
)
88125
gateway_instance = self._get_gateway_instance_and_context(ctx=context, sid=sid)
89-
90-
await self.run_route_guards(context=context)
91-
await self._run_handler(context=context, gateway_instance=gateway_instance)
126+
await self._run_with_exception_handling(gateway_instance, sid=sid)
92127

93128
def _register_handler(self) -> None:
94129
@self._server.on(self._event)
@@ -129,7 +164,7 @@ async def _run_handler(
129164
await self._server.emit(**res.dict())
130165

131166
@t.no_type_check
132-
async def run_route_guards(self, context: IExecutionContext) -> None:
167+
async def run_route_guards(self, context: GatewayContext) -> None:
133168
reflector = context.get_service_provider().get(Reflector)
134169
app = context.get_app()
135170

@@ -149,7 +184,15 @@ async def run_route_guards(self, context: IExecutionContext) -> None:
149184

150185
result = await guard.can_activate(context)
151186
if not result:
152-
guard.raise_exception()
187+
await self._server.emit(
188+
"error",
189+
{
190+
"code": status.WS_1011_INTERNAL_ERROR,
191+
"reason": "Authorization Failed",
192+
},
193+
room=context.sid,
194+
)
195+
await self._server.disconnect(sid=context.sid)
153196

154197
def get_control_type(self) -> t.Type[GatewayBase]:
155198
"""
@@ -159,19 +202,17 @@ def get_control_type(self) -> t.Type[GatewayBase]:
159202
"""
160203
if not hasattr(self, "_control_type"):
161204
_control_type = reflect.get_metadata(CONTROLLER_CLASS_KEY, self.endpoint)
162-
if _control_type is None:
205+
if _control_type is None: # pragma: no cover
163206
raise Exception("Operation must have a single control type.")
164207
self._control_type = t.cast(t.Type[GatewayBase], _control_type)
165208

166209
return self._control_type
167210

168211
def _get_gateway_instance(self, ctx: IExecutionContext) -> GatewayBase:
169-
gateway_type: t.Optional[t.Type] = reflect.get_metadata(
170-
CONTROLLER_CLASS_KEY, self.endpoint
171-
)
212+
gateway_type = self.get_control_type()
172213
if not gateway_type or (
173214
gateway_type and not issubclass(gateway_type, GatewayBase)
174-
):
215+
): # pragma: no cover
175216
raise RuntimeError("GatewayBase Type was not found")
176217

177218
service_provider = ctx.get_service_provider()
@@ -197,8 +238,6 @@ def _get_gateway_instance_and_context(
197238
class SocketMessageOperation(SocketOperationConnection):
198239
async def _context_handler(self, sid: str, message: t.Any) -> t.Any:
199240
sid_environ = self._server.get_environ(sid)
200-
if not sid_environ:
201-
raise Exception("Socket Environment not found.")
202241

203242
service_provider = t.cast(
204243
EllarInjector, sid_environ["asgi.scope"][SCOPE_SERVICE_PROVIDER]
@@ -216,8 +255,7 @@ async def _context_handler(self, sid: str, message: t.Any) -> t.Any:
216255
ctx=context, sid=sid, message=message
217256
)
218257

219-
await self.run_route_guards(context=context)
220-
await self._run_handler(context=context, gateway_instance=gateway_instance)
258+
await self._run_with_exception_handling(gateway_instance, sid)
221259

222260
def _register_handler(self) -> None:
223261
@self._server.on(self._event)

0 commit comments

Comments
 (0)