Skip to content

Commit 45dfff9

Browse files
committed
Added more test on websocket
1 parent eafbd65 commit 45dfff9

File tree

4 files changed

+199
-15
lines changed

4 files changed

+199
-15
lines changed

ellar/core/routing/operation_definitions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def ws_route(
263263
path: str = "/",
264264
*,
265265
name: str = None,
266-
encoding: str = "json",
266+
encoding: t.Optional[str] = "json",
267267
use_extra_handler: bool = False,
268268
extra_handler_type: t.Optional[t.Type] = None,
269269
) -> t.Callable[[TCallable], t.Union[TCallable, TWebsocketOperation]]:

ellar/core/routing/websocket/handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,13 @@ async def dispatch(
6565
elif message["type"] == "websocket.disconnect":
6666
close_code = int(message.get("code", status.WS_1000_NORMAL_CLOSURE))
6767
break
68+
except WebSocketException as wexc:
69+
await websocket.close(code=wexc.code)
70+
raise RuntimeError(wexc.reason)
6871
except Exception as exc:
6972
close_code = status.WS_1011_INTERNAL_ERROR
7073
raise exc
74+
7175
finally:
7276
await self.execute_on_disconnect(context=context, close_code=close_code)
7377

ellar/core/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class WsRouteParameters(BaseModel):
8484
path: str
8585
name: t.Optional[str] = None
8686
endpoint: t.Callable
87-
encoding: str = Field("json")
87+
encoding: t.Optional[str] = Field("json")
8888
use_extra_handler: bool = Field(False)
8989
extra_handler_type: t.Optional[t.Type["WebSocketExtraHandler"]] = None
9090

@@ -94,7 +94,7 @@ def validate_endpoint(cls, value: t.Any):
9494

9595
@validator("encoding")
9696
def validate_encoding(cls, value: t.Any):
97-
if value not in ["json", "text", "bytes"]:
97+
if value not in ["json", "text", "bytes", None]:
9898
raise ValueError(
9999
f"Encoding type not supported. Once [json | text | bytes]. Received: {value}"
100100
)

tests/test_websocket_handler.py

Lines changed: 192 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import pytest
2-
from starlette.websockets import WebSocket
2+
from starlette.websockets import WebSocket, WebSocketState
33

4-
from ellar.common import ModuleRouter, WsBody
4+
from ellar.common import Controller, ModuleRouter, WsBody, ws_route
55
from ellar.core import TestClientFactory
66
from ellar.core.exceptions import ImproperConfiguration
77

88
from .schema import Item
99

10-
router = ModuleRouter("/")
10+
router = ModuleRouter("/router")
1111

1212

1313
@router.ws_route("/ws-with-handler", use_extra_handler=True)
@@ -24,6 +24,12 @@ async def websocket_with_handler_connect(websocket: WebSocket):
2424
await websocket.accept()
2525

2626

27+
@websocket_with_handler.disconnect
28+
async def websocket_with_handler_connect(websocket: WebSocket, code: int):
29+
# await websocket.close(code=code)
30+
assert websocket.client_state == WebSocketState.DISCONNECTED
31+
32+
2733
@router.ws_route("/ws")
2834
async def websocket_without_handler(websocket: WebSocket, query: str):
2935
assert query == "my-query"
@@ -33,12 +39,45 @@ async def websocket_without_handler(websocket: WebSocket, query: str):
3339
await websocket.close()
3440

3541

36-
tm = TestClientFactory.create_test_module(routers=[router])
42+
@Controller("/controller")
43+
class WebsocketController:
44+
@ws_route("/ws-with-handler", use_extra_handler=True)
45+
async def websocket_with_handler_c(
46+
self, websocket: WebSocket, query: str, data: Item = WsBody()
47+
):
48+
assert query == "my-query"
49+
await websocket.send_json(data.dict())
50+
await websocket.close()
51+
52+
@websocket_with_handler_c.connect
53+
async def websocket_with_handler_connect(self, websocket: WebSocket):
54+
await websocket.accept()
55+
56+
@websocket_with_handler_c.disconnect
57+
async def websocket_with_handler_connect(self, websocket: WebSocket, code: int):
58+
# await websocket.close(code=code)
59+
assert websocket.client_state == WebSocketState.DISCONNECTED
60+
61+
@ws_route("/ws")
62+
async def websocket_without_handler_c(self, websocket: WebSocket, query: str):
63+
assert query == "my-query"
64+
await websocket.accept()
65+
message = await websocket.receive_text()
66+
await websocket.send_json({"message": f"Thanks. {message}"})
67+
await websocket.close()
68+
69+
70+
tm = TestClientFactory.create_test_module(
71+
routers=[router], controllers=(WebsocketController,)
72+
)
3773
client = tm.get_client()
3874

3975

40-
def test_websocket_with_handler_works():
41-
with client.websocket_connect("/ws-with-handler?query=my-query") as session:
76+
@pytest.mark.parametrize("prefix", ["/router", "/controller"])
77+
def test_websocket_with_handler_works(prefix):
78+
with client.websocket_connect(
79+
f"{prefix}/ws-with-handler?query=my-query"
80+
) as session:
4281
session.send_json(Item(name="Ellar", price=23.34, tax=1.2).dict())
4382
message = session.receive_json()
4483
assert message == {
@@ -49,9 +88,12 @@ def test_websocket_with_handler_works():
4988
}
5089

5190

52-
def test_websocket_with_handler_fails_for_invalid_input():
91+
@pytest.mark.parametrize("prefix", ["/router", "/controller"])
92+
def test_websocket_with_handler_fails_for_invalid_input(prefix):
5393
with pytest.raises(Exception):
54-
with client.websocket_connect("/ws-with-handler?query=my-query") as session:
94+
with client.websocket_connect(
95+
f"{prefix}/ws-with-handler?query=my-query"
96+
) as session:
5597
session.send_json({"framework": "Ellar is awesome"})
5698
message = session.receive_json()
5799
assert message == {
@@ -71,9 +113,10 @@ def test_websocket_with_handler_fails_for_invalid_input():
71113
}
72114

73115

74-
def test_websocket_with_handler_fails_for_missing_route_parameter():
116+
@pytest.mark.parametrize("prefix", ["/router", "/controller"])
117+
def test_websocket_with_handler_fails_for_missing_route_parameter(prefix):
75118
with pytest.raises(Exception):
76-
with client.websocket_connect("/ws-with-handler") as session:
119+
with client.websocket_connect(f"{prefix}/ws-with-handler") as session:
77120
session.send_json(Item(name="Ellar", price=23.34, tax=1.2).dict())
78121
message = session.receive_json()
79122
assert message == {
@@ -88,8 +131,9 @@ def test_websocket_with_handler_fails_for_missing_route_parameter():
88131
}
89132

90133

91-
def test_plain_websocket_route():
92-
with client.websocket_connect("/ws?query=my-query") as websocket:
134+
@pytest.mark.parametrize("prefix", ["/router", "/controller"])
135+
def test_plain_websocket_route(prefix):
136+
with client.websocket_connect(f"{prefix}/ws?query=my-query") as websocket:
93137
websocket.send_text("Ellar")
94138
message = websocket.receive_json()
95139
assert message == {"message": "Thanks. Ellar"}
@@ -106,3 +150,139 @@ async def websocket_with_handler(
106150
websocket: WebSocket, query: str, data: Item = WsBody()
107151
):
108152
pass
153+
154+
155+
def test_websocket_endpoint_on_connect():
156+
@Controller("/ws")
157+
class WebSocketSample:
158+
@ws_route(use_extra_handler=True)
159+
async def ws(self, websocket: WebSocket):
160+
pass
161+
162+
@ws.connect
163+
async def on_connect(self, websocket):
164+
assert websocket["subprotocols"] == ["soap", "wamp"]
165+
await websocket.accept(subprotocol="wamp")
166+
167+
_client = TestClientFactory.create_test_module(
168+
controllers=(WebSocketSample,)
169+
).get_client()
170+
with _client.websocket_connect("/ws/", subprotocols=["soap", "wamp"]) as websocket:
171+
assert websocket.accepted_subprotocol == "wamp"
172+
173+
174+
def test_websocket_endpoint_on_receive_bytes():
175+
@Controller("/ws")
176+
class WebSocketSample:
177+
@ws_route(use_extra_handler=True, encoding="bytes")
178+
async def ws(self, websocket: WebSocket, data: bytes = WsBody()):
179+
await websocket.send_bytes(b"Message bytes was: " + data)
180+
181+
_client = TestClientFactory.create_test_module(
182+
controllers=(WebSocketSample,)
183+
).get_client()
184+
with _client.websocket_connect("/ws/") as websocket:
185+
websocket.send_bytes(b"Hello, world!")
186+
_bytes = websocket.receive_bytes()
187+
assert _bytes == b"Message bytes was: Hello, world!"
188+
189+
with pytest.raises(RuntimeError):
190+
with _client.websocket_connect("/ws/") as websocket:
191+
websocket.send_text("Hello world")
192+
193+
194+
def test_websocket_endpoint_on_receive_json():
195+
@Controller("/ws")
196+
class WebSocketSample:
197+
@ws_route(use_extra_handler=True, encoding="json")
198+
async def ws(self, websocket: WebSocket, data=WsBody()):
199+
await websocket.send_json({"message": data})
200+
201+
_client = TestClientFactory.create_test_module(
202+
controllers=(WebSocketSample,)
203+
).get_client()
204+
205+
with _client.websocket_connect("/ws/") as websocket:
206+
websocket.send_json({"hello": "world"})
207+
data = websocket.receive_json()
208+
assert data == {"message": {"hello": "world"}}
209+
210+
with pytest.raises(RuntimeError):
211+
with _client.websocket_connect("/ws/") as websocket:
212+
websocket.send_text("Hello world")
213+
214+
215+
def test_websocket_endpoint_on_receive_json_binary():
216+
@Controller("/ws")
217+
class WebSocketSample:
218+
@ws_route(use_extra_handler=True, encoding="json")
219+
async def ws(self, websocket: WebSocket, data=WsBody()):
220+
await websocket.send_json({"message": data}, mode="binary")
221+
222+
_client = TestClientFactory.create_test_module(
223+
controllers=(WebSocketSample,)
224+
).get_client()
225+
226+
with _client.websocket_connect("/ws/") as websocket:
227+
websocket.send_json({"hello": "world"}, mode="binary")
228+
data = websocket.receive_json(mode="binary")
229+
assert data == {"message": {"hello": "world"}}
230+
231+
232+
def test_websocket_endpoint_on_receive_text():
233+
@Controller("/ws")
234+
class WebSocketSample:
235+
@ws_route(use_extra_handler=True, encoding="text")
236+
async def ws(self, websocket: WebSocket, data: str = WsBody()):
237+
await websocket.send_text(f"Message text was: {data}")
238+
239+
_client = TestClientFactory.create_test_module(
240+
controllers=(WebSocketSample,)
241+
).get_client()
242+
243+
with _client.websocket_connect("/ws/") as websocket:
244+
websocket.send_text("Hello, world!")
245+
_text = websocket.receive_text()
246+
assert _text == "Message text was: Hello, world!"
247+
248+
with pytest.raises(RuntimeError):
249+
with _client.websocket_connect("/ws/") as websocket:
250+
websocket.send_bytes(b"Hello world")
251+
252+
253+
def test_websocket_endpoint_on_default():
254+
@Controller("/ws")
255+
class WebSocketSample:
256+
@ws_route(use_extra_handler=True, encoding=None)
257+
async def ws(self, websocket: WebSocket, data: str = WsBody()):
258+
await websocket.send_text(f"Message text was: {data}")
259+
260+
_client = TestClientFactory.create_test_module(
261+
controllers=(WebSocketSample,)
262+
).get_client()
263+
264+
with _client.websocket_connect("/ws/") as websocket:
265+
websocket.send_text("Hello, world!")
266+
_text = websocket.receive_text()
267+
assert _text == "Message text was: Hello, world!"
268+
269+
270+
def test_websocket_endpoint_on_disconnect():
271+
@Controller("/ws")
272+
class WebSocketSample:
273+
@ws_route(use_extra_handler=True, encoding=None)
274+
async def ws(self, websocket: WebSocket, data: str = WsBody()):
275+
await websocket.send_text(f"Message text was: {data}")
276+
277+
@ws.disconnect
278+
async def on_disconnect(self, websocket: WebSocket, close_code):
279+
assert close_code == 1001
280+
await websocket.close(code=close_code)
281+
282+
_client = TestClientFactory.create_test_module(
283+
controllers=(WebSocketSample,)
284+
).get_client()
285+
286+
with _client.websocket_connect("/ws/") as websocket:
287+
websocket.send_text("Hello, world!")
288+
websocket.close(code=1001)

0 commit comments

Comments
 (0)