Skip to content

Commit 432b32a

Browse files
committed
Added testing
1 parent 7d4a1a7 commit 432b32a

File tree

7 files changed

+568
-3
lines changed

7 files changed

+568
-3
lines changed

docs/overview/custom_decorators.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@ from ellar.core import WebSocket
4646
def sample_endpoint_ws(self, websocket: WebSocket, data_schema: str = WsBody()):
4747
pass
4848

49-
@sample_endpoint_ws.connect
49+
@ws_route.connect(sample_endpoint_ws)
5050
async def on_connect(self, websocket: WebSocket):
51+
# Called when there is a connection to `sample_endpoint_ws`
5152
await websocket.accept()
5253

53-
@sample_endpoint_ws.disconnect
54-
async def on_connect(self, websocket: WebSocket, code: int):
54+
@ws_route.disconnect(sample_endpoint_ws)
55+
async def on_disconnect(self, websocket: WebSocket, code: int):
56+
# Called when there is a disconnect from `sample_endpoint_ws`
5557
await websocket.close(code)
5658
```
5759

tests/test_socket_io/__init__.py

Whitespace-only changes.

tests/test_socket_io/sample.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import typing as t
2+
from functools import wraps
3+
4+
from starlette import status
5+
from starlette.exceptions import WebSocketException
6+
7+
from ellar.common import Guards, Header, Query, WsBody, extra_args
8+
from ellar.core.connection import HTTPConnection
9+
from ellar.core.guard import APIKeyHeader
10+
from ellar.core.params import ExtraEndpointArg
11+
from ellar.core.serializer import Serializer
12+
from ellar.di import injectable
13+
from ellar.socket_io import (
14+
WebSocketGateway,
15+
WsResponse,
16+
on_connected,
17+
on_disconnected,
18+
subscribe_message,
19+
)
20+
from ellar.socket_io.model import GatewayBase
21+
22+
23+
def add_extra_args(func):
24+
# EXTRA ARGS SETUP
25+
query1 = ExtraEndpointArg(name="query1", annotation=str, default_value=Query())
26+
query2 = ExtraEndpointArg(
27+
name="query2", annotation=str
28+
) # will default to Query during computation
29+
30+
extra_args(query1, query2)(func)
31+
32+
@wraps(func)
33+
def _wrapper(*args, **kwargs):
34+
# RESOLVING EXTRA ARGS
35+
# All extra args must be resolved before calling route function
36+
# else extra argument will be pushed to the route function
37+
resolved_query1 = query1.resolve(kwargs)
38+
resolved_query2 = query2.resolve(kwargs)
39+
40+
return func(*args, **kwargs, query2=resolved_query2, query1=resolved_query1)
41+
42+
return _wrapper
43+
44+
45+
@injectable()
46+
class HeaderGuard(APIKeyHeader):
47+
parameter_name = "x-auth-key"
48+
49+
async def authenticate(
50+
self, connection: HTTPConnection, key: t.Optional[t.Any]
51+
) -> t.Optional[t.Any]:
52+
if key == "supersecret":
53+
return key
54+
55+
56+
class MessageData(Serializer):
57+
data: t.Any
58+
59+
60+
class MessageRoom(Serializer):
61+
room: str
62+
63+
64+
class MessageChatRoom(Serializer):
65+
room: str
66+
data: t.Any
67+
68+
69+
@WebSocketGateway(path="/ws", async_mode="asgi", cors_allowed_origins="*")
70+
class EventGateway:
71+
@subscribe_message("my_event")
72+
async def my_event(self, message: MessageData = WsBody()):
73+
return WsResponse("my_response", {"data": message.data}, room=self.context.sid)
74+
75+
@subscribe_message
76+
async def my_broadcast_event(self, message: MessageData = WsBody()):
77+
await self.context.server.emit("my_response", {"data": message.data})
78+
79+
@on_connected()
80+
async def connect(self):
81+
await self.context.server.emit(
82+
"my_response", {"data": "Connected", "count": 0}, room=self.context.sid
83+
)
84+
85+
@on_disconnected()
86+
async def disconnect(self):
87+
print("Client disconnected")
88+
89+
90+
@Guards(HeaderGuard)
91+
@WebSocketGateway(path="/ws-guard")
92+
class GatewayWithGuards(GatewayBase):
93+
@subscribe_message("my_event")
94+
async def my_event(self, message: MessageData = WsBody()):
95+
return WsResponse(
96+
"my_response",
97+
{
98+
"data": message.data,
99+
"auth-key": self.context.switch_to_http_connection().get_client().user,
100+
},
101+
)
102+
103+
@subscribe_message("my_event_header")
104+
async def my_event_header(
105+
self,
106+
data: str = WsBody(..., embed=True),
107+
x_auth_key: str = Header(alias="x-auth-key"),
108+
):
109+
return WsResponse("my_response", {"data": data, "x_auth_key": x_auth_key})
110+
111+
112+
@WebSocketGateway(path="/ws-others")
113+
class GatewayOthers(GatewayBase):
114+
@subscribe_message("my_event")
115+
async def my_event(self, message: MessageData = WsBody()):
116+
raise Exception("I dont have anything to run.")
117+
118+
@subscribe_message("my_event_raise")
119+
async def my_event_raise(self, message: MessageData = WsBody()):
120+
raise WebSocketException(
121+
code=status.WS_1009_MESSAGE_TOO_BIG, reason="Message is too big"
122+
)
123+
124+
@subscribe_message("extra_args")
125+
@add_extra_args
126+
def extra_args_handler(self, query1: str, query2: str):
127+
raise WebSocketException(
128+
code=status.WS_1009_MESSAGE_TOO_BIG,
129+
reason={"query1": query1, "query2": query2},
130+
)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from ellar.reflect import reflect
2+
from ellar.socket_io import on_connected, on_disconnected
3+
from ellar.socket_io.constants import (
4+
CONNECTION_EVENT,
5+
DISCONNECT_EVENT,
6+
MESSAGE_MAPPING_METADATA,
7+
)
8+
9+
10+
def test_on_connected_decorator_works():
11+
@on_connected()
12+
def sample_function():
13+
pass
14+
15+
assert getattr(sample_function, MESSAGE_MAPPING_METADATA)
16+
assert reflect.get_metadata_or_raise_exception(CONNECTION_EVENT, sample_function)
17+
18+
19+
def test_on_disconnected_decorator_works():
20+
@on_disconnected()
21+
def sample_function():
22+
pass
23+
24+
assert getattr(sample_function, MESSAGE_MAPPING_METADATA)
25+
assert reflect.get_metadata_or_raise_exception(DISCONNECT_EVENT, sample_function)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import pytest
2+
3+
from ellar.common import Guards
4+
from ellar.constants import CONTROLLER_CLASS_KEY, GUARDS_KEY
5+
from ellar.core.guard import HttpBearerAuth
6+
from ellar.helper import get_name
7+
from ellar.reflect import reflect
8+
from ellar.socket_io import WebSocketGateway, subscribe_message
9+
from ellar.socket_io.constants import (
10+
GATEWAY_MESSAGE_HANDLER_KEY,
11+
GATEWAY_METADATA,
12+
GATEWAY_OPTIONS,
13+
GATEWAY_WATERMARK,
14+
MESSAGE_METADATA,
15+
)
16+
from ellar.socket_io.model import GatewayBase, GatewayType
17+
18+
19+
@WebSocketGateway(path="/ws", namespace="/some-namespace")
20+
@Guards(HttpBearerAuth)
21+
class SampleWithoutGateway:
22+
pass
23+
24+
25+
@WebSocketGateway(path="/ws", namespace="/some-namespace")
26+
class SampleWithGateway(GatewayBase):
27+
pass
28+
29+
30+
@WebSocketGateway
31+
class SampleMarkAsGateway:
32+
pass
33+
34+
35+
@pytest.mark.parametrize(
36+
"gateway, watermark, options",
37+
[
38+
(
39+
SampleWithoutGateway,
40+
True,
41+
{
42+
"async_mode": "asgi",
43+
"cors_allowed_origins": "*",
44+
"namespace": "/some-namespace",
45+
},
46+
),
47+
(
48+
SampleWithGateway,
49+
False,
50+
{
51+
"async_mode": "asgi",
52+
"cors_allowed_origins": "*",
53+
"namespace": "/some-namespace",
54+
},
55+
),
56+
(
57+
SampleMarkAsGateway,
58+
True,
59+
{"async_mode": "asgi", "cors_allowed_origins": "*"},
60+
),
61+
],
62+
)
63+
def test_websocket_gateway_works_without_gateway(gateway, watermark, options):
64+
65+
assert isinstance(gateway, GatewayType)
66+
assert hasattr(gateway, "__GATEWAY_WATERMARK__") is watermark
67+
68+
assert reflect.get_metadata_or_raise_exception(GATEWAY_WATERMARK, gateway) is True
69+
assert reflect.get_metadata_or_raise_exception(GATEWAY_OPTIONS, gateway) == options
70+
71+
for key in GATEWAY_METADATA.keys:
72+
reflect.get_metadata_or_raise_exception(key, gateway)
73+
74+
75+
def test_gateway_guard():
76+
assert reflect.get_metadata_or_raise_exception(
77+
GUARDS_KEY, SampleWithoutGateway
78+
) == [HttpBearerAuth]
79+
80+
81+
def test_sub_message_building_works():
82+
@WebSocketGateway(path="/ws", namespace="/some-namespace")
83+
class SampleAGateway(GatewayBase):
84+
@subscribe_message
85+
def a_message(self):
86+
pass
87+
88+
message_handlers = reflect.get_metadata_or_raise_exception(
89+
GATEWAY_MESSAGE_HANDLER_KEY, SampleAGateway
90+
)
91+
assert len(message_handlers) == 1
92+
assert get_name(message_handlers[0]) == "a_message"
93+
message = reflect.get_metadata_or_raise_exception(
94+
MESSAGE_METADATA, message_handlers[0]
95+
)
96+
assert message == "a_message"
97+
assert (
98+
reflect.get_metadata_or_raise_exception(
99+
CONTROLLER_CLASS_KEY, SampleAGateway().a_message
100+
)
101+
is SampleAGateway
102+
)
103+
104+
105+
def test_sub_message_building_fails():
106+
with pytest.raises(Exception) as ex:
107+
108+
@WebSocketGateway(path="/ws", namespace="/some-namespace")
109+
class SampleBGateway(GatewayBase):
110+
@subscribe_message
111+
@reflect.metadata(CONTROLLER_CLASS_KEY, "b_message")
112+
def b_message(self):
113+
pass
114+
115+
assert (
116+
"SampleBGateway Gateway message handler tried to be processed more than once"
117+
in str(ex.value)
118+
)
119+
120+
121+
def test_cant_use_gateway_decorator_on_function():
122+
with pytest.raises(Exception) as ex:
123+
124+
@WebSocketGateway(path="/ws", namespace="/some-namespace")
125+
def sample_c_gateway():
126+
pass
127+
128+
assert "WebSocketGateway is a class decorator" in str(ex.value)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from ellar.reflect import reflect
2+
from ellar.socket_io import subscribe_message
3+
from ellar.socket_io.constants import MESSAGE_MAPPING_METADATA, MESSAGE_METADATA
4+
5+
6+
def test_subscribe_message_works():
7+
@subscribe_message("sample")
8+
def sample_function():
9+
pass
10+
11+
assert getattr(sample_function, MESSAGE_MAPPING_METADATA)
12+
assert (
13+
reflect.get_metadata_or_raise_exception(MESSAGE_METADATA, sample_function)
14+
== "sample"
15+
)
16+
17+
@subscribe_message
18+
def sample_function2():
19+
pass
20+
21+
assert getattr(sample_function2, MESSAGE_MAPPING_METADATA)
22+
assert (
23+
reflect.get_metadata_or_raise_exception(MESSAGE_METADATA, sample_function2)
24+
== "sample_function2"
25+
)

0 commit comments

Comments
 (0)