Skip to content

Commit cd1befa

Browse files
committed
fixed websocket route function computation
1 parent 18d1390 commit cd1befa

File tree

4 files changed

+20
-4
lines changed

4 files changed

+20
-4
lines changed

ellar/common/decorators/controller.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import typing as t
33
from abc import ABC
4+
from types import FunctionType
45

56
from ellar.compatible import AttributeDict
67
from ellar.constants import (
@@ -14,16 +15,21 @@
1415
from ellar.core import ControllerBase
1516
from ellar.core.controller import ControllerType
1617
from ellar.core.exceptions import ImproperConfiguration
18+
from ellar.core.routing.controller import ControllerRouteOperationBase
1719
from ellar.di import RequestScope, injectable
1820
from ellar.reflect import reflect
1921

2022
if t.TYPE_CHECKING: # pragma: no cover
2123
from ellar.core.guard import GuardCanActivate
2224

2325

24-
def get_route_functions(cls: t.Type) -> t.Iterable[t.Callable]:
26+
def get_route_functions(
27+
cls: t.Type,
28+
) -> t.Iterable[t.Union[t.Callable, ControllerRouteOperationBase]]:
2529
for method in cls.__dict__.values():
26-
if hasattr(method, OPERATION_ENDPOINT_KEY):
30+
if hasattr(method, OPERATION_ENDPOINT_KEY) or isinstance(
31+
method, ControllerRouteOperationBase
32+
):
2733
yield method
2834

2935

@@ -33,7 +39,11 @@ def reflect_all_controller_type_routes(cls: t.Type[ControllerBase]) -> None:
3339
for base_cls in reversed(bases):
3440
if base_cls not in [ABC, ControllerBase, object]:
3541
for item in get_route_functions(base_cls):
36-
operation = reflect.get_metadata(CONTROLLER_OPERATION_HANDLER_KEY, item)
42+
operation = item
43+
if callable(item) and type(item) == FunctionType:
44+
operation = reflect.get_metadata( # type: ignore
45+
CONTROLLER_OPERATION_HANDLER_KEY, item
46+
)
3747
reflect.define_metadata(CONTROLLER_CLASS_KEY, cls, item)
3848
reflect.define_metadata(
3949
CONTROLLER_OPERATION_HANDLER_KEY,

ellar/core/routing/controller/websocket/route.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing as t
22

33
from starlette.status import WS_1008_POLICY_VIOLATION
4+
from starlette.websockets import WebSocketState
45

56
from ellar.core.context import IExecutionContext
67
from ellar.core.exceptions import WebSocketRequestValidationError
@@ -27,6 +28,8 @@ async def _handle_request(self, context: IExecutionContext) -> None:
2728
if errors:
2829
websocket = context.switch_to_websocket().get_client()
2930
exc = WebSocketRequestValidationError(errors)
31+
if websocket.client_state == WebSocketState.CONNECTING:
32+
await websocket.accept()
3033
await websocket.send_json(
3134
dict(code=WS_1008_POLICY_VIOLATION, errors=exc.errors())
3235
)

ellar/core/routing/operation_definitions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _get_ws_operation(
7272
_operation = _ws_operation_class(**ws_route_parameters.dict())
7373
setattr(ws_route_parameters.endpoint, OPERATION_ENDPOINT_KEY, True)
7474
if self._routes is not None and not isinstance(
75-
_operation, ControllerRouteOperationBase
75+
_operation, ControllerWebsocketRouteOperation
7676
):
7777
self._routes.append(_operation)
7878
return _operation

ellar/core/routing/websocket/route.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from starlette.routing import WebSocketRoute as StarletteWebSocketRoute, compile_path
44
from starlette.status import WS_1008_POLICY_VIOLATION
5+
from starlette.websockets import WebSocketState
56

67
from ellar.constants import (
78
CONTROLLER_OPERATION_HANDLER_KEY,
@@ -105,6 +106,8 @@ async def _handle_request(self, context: IExecutionContext) -> None:
105106
if errors:
106107
websocket = context.switch_to_websocket().get_client()
107108
exc = WebSocketRequestValidationError(errors)
109+
if websocket.client_state == WebSocketState.CONNECTING:
110+
await websocket.accept()
108111
await websocket.send_json(
109112
dict(code=WS_1008_POLICY_VIOLATION, errors=exc.errors())
110113
)

0 commit comments

Comments
 (0)