diff --git a/starlette/websockets.py b/starlette/websockets.py index 6b46f4eae..a3944330e 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -22,6 +22,14 @@ def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.reason = reason or "" +class WebSocketDisconnected(RuntimeError): + """ + Raised when attempting to use a disconnected WebSocket. + """ + + pass + + class WebSocket(HTTPConnection): def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: super().__init__(scope) @@ -53,7 +61,7 @@ async def receive(self) -> Message: self.client_state = WebSocketState.DISCONNECTED return message else: - raise RuntimeError('Cannot call "receive" once a disconnect message has been received.') + raise WebSocketDisconnected('Cannot call "receive" once a disconnect message has been received.') async def send(self, message: Message) -> None: """ @@ -94,7 +102,7 @@ async def send(self, message: Message) -> None: self.application_state = WebSocketState.DISCONNECTED await self._send(message) else: - raise RuntimeError('Cannot call "send" once a close message has been sent.') + raise WebSocketDisconnected('Cannot call "send" once a close message has been sent.') async def accept( self, @@ -114,14 +122,14 @@ def _raise_on_disconnect(self, message: Message) -> None: async def receive_text(self) -> str: if self.application_state != WebSocketState.CONNECTED: - raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + raise WebSocketDisconnected('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) return typing.cast(str, message["text"]) async def receive_bytes(self) -> bytes: if self.application_state != WebSocketState.CONNECTED: - raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + raise WebSocketDisconnected('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) return typing.cast(bytes, message["bytes"]) @@ -130,7 +138,7 @@ async def receive_json(self, mode: str = "text") -> typing.Any: if mode not in {"text", "binary"}: raise RuntimeError('The "mode" argument should be "text" or "binary".') if self.application_state != WebSocketState.CONNECTED: - raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + raise WebSocketDisconnected('WebSocket is not connected. Need to call "accept" first.') message = await self.receive() self._raise_on_disconnect(message) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index e76d8f29b..450da3e9f 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -10,7 +10,7 @@ from starlette.responses import Response from starlette.testclient import WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send -from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketDisconnected, WebSocketState from tests.types import TestClientFactory @@ -449,7 +449,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.close() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnected): with client.websocket_connect("/"): pass # pragma: no cover @@ -463,7 +463,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: message = await websocket.receive() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnected): with client.websocket_connect("/") as websocket: websocket.close() @@ -539,7 +539,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.receive_text() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnected): with client.websocket_connect("/"): pass # pragma: no cover @@ -550,7 +550,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.receive_bytes() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnected): with client.websocket_connect("/"): pass # pragma: no cover @@ -561,7 +561,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await websocket.receive_json() client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(WebSocketDisconnected): with client.websocket_connect("/"): pass # pragma: no cover @@ -577,6 +577,42 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: pass # pragma: no cover +def test_receive_text_after_close(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.close() + await websocket.receive_text() + + client = test_client_factory(app) + with pytest.raises(WebSocketDisconnected): + with client.websocket_connect("/"): + pass # pragma: no cover + + +def test_receive_bytes_after_close(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.close() + await websocket.receive_bytes() + + client = test_client_factory(app) + with pytest.raises(WebSocketDisconnected): + with client.websocket_connect("/"): + pass # pragma: no cover + + +def test_receive_json_after_close(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.close() + await websocket.receive_json() + + client = test_client_factory(app) + with pytest.raises(WebSocketDisconnected): + with client.websocket_connect("/"): + pass # pragma: no cover + + def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) @@ -602,6 +638,19 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket.send({"type": "websocket.send"}) +def test_receive_after_close(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + websocket.client_state = WebSocketState.DISCONNECTED + await websocket.receive() + + client = test_client_factory(app) + with pytest.raises(WebSocketDisconnected): + with client.websocket_connect("/"): + pass # pragma: no cover + + def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send)