diff --git a/tests/test_connection.py b/tests/test_connection.py index 45c1268..cbb4f4b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -401,6 +401,7 @@ def test_client_open_url_options( # type: ignore[misc] 'extra_headers': [(b'X-Test-Header', b'My test header')], 'message_queue_size': 9, 'max_message_size': 333, + 'receive_buffer_size': 999, 'connect_timeout': 36, 'disconnect_timeout': 37, } diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index b7c1ea4..dbd28c3 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -117,6 +117,7 @@ async def open_websocket( extra_headers: Optional[list[tuple[bytes,bytes]]] = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT ) -> AsyncGenerator[WebSocketConnection, None]: @@ -144,6 +145,9 @@ async def open_websocket( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for the connection before timing out. :param float disconnect_timeout: The number of seconds to wait when closing @@ -182,7 +186,8 @@ async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: resource, use_ssl=use_ssl, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) except trio.TooSlowError: raise ConnectionTimeout from None except OSError as e: @@ -311,6 +316,7 @@ async def connect_websocket( extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketConnection: ''' Return an open WebSocket client connection to a host. @@ -339,6 +345,9 @@ async def connect_websocket( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :rtype: WebSocketConnection ''' if use_ssl is True: @@ -368,7 +377,8 @@ async def connect_websocket( path=resource, client_subprotocols=subprotocols, client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection @@ -384,6 +394,7 @@ def open_websocket_url( max_message_size: int = MAX_MESSAGE_SIZE, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> AbstractAsyncContextManager[WebSocketConnection]: ''' Open a WebSocket client connection to a URL. @@ -407,6 +418,9 @@ def open_websocket_url( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for the connection before timing out. :param float disconnect_timeout: The number of seconds to wait when closing @@ -420,6 +434,7 @@ def open_websocket_url( subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size, connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) @@ -432,6 +447,7 @@ async def connect_websocket_url( extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketConnection: ''' Return an open WebSocket client connection to a URL. @@ -457,13 +473,17 @@ async def connect_websocket_url( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :rtype: WebSocketConnection ''' host, port, resource, return_ssl_context = _url_to_host(url, ssl_context) return await connect_websocket(nursery, host, port, resource, use_ssl=return_ssl_context, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) def _url_to_host( @@ -520,6 +540,7 @@ async def wrap_client_stream( extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketConnection: ''' Wrap an arbitrary stream in a WebSocket connection. @@ -544,6 +565,9 @@ async def wrap_client_stream( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :rtype: WebSocketConnection ''' connection = WebSocketConnection(stream, @@ -551,7 +575,8 @@ async def wrap_client_stream( host=host, path=resource, client_subprotocols=subprotocols, client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection @@ -562,6 +587,7 @@ async def wrap_server_stream( stream: trio.abc.Stream, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketRequest: ''' Wrap an arbitrary stream in a server-side WebSocket. @@ -576,6 +602,9 @@ async def wrap_server_stream( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :type stream: trio.abc.Stream :rtype: WebSocketRequest ''' @@ -583,12 +612,14 @@ async def wrap_server_stream( stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) nursery.start_soon(connection._reader_task) request = await connection._get_request() return request + async def serve_websocket( handler: Callable[[WebSocketRequest], Awaitable[None]], host: str | bytes | None, @@ -598,6 +629,7 @@ async def serve_websocket( handler_nursery: trio.Nursery | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, @@ -630,6 +662,9 @@ async def serve_websocket( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for a client to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client @@ -658,6 +693,7 @@ async def serve_websocket( handler_nursery=handler_nursery, message_queue_size=message_queue_size, max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size, connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout, ) @@ -957,7 +993,8 @@ def __init__( client_subprotocols: Iterable[str] | None = None, client_extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, - max_message_size: int = MAX_MESSAGE_SIZE + max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> None: ''' Constructor. @@ -984,6 +1021,9 @@ def __init__( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. ''' # NOTE: The implementation uses _close_reason for more than an advisory # purpose. It's critical internal state, indicating when the @@ -996,6 +1036,7 @@ def __init__( self._message_size = 0 self._message_parts: List[Union[bytes, str]] = [] self._max_message_size = max_message_size + self._receive_buffer_size: Optional[int] = receive_buffer_size self._reader_running = True if ws_connection.client: assert host is not None @@ -1528,7 +1569,7 @@ async def _reader_task(self) -> None: # Get network data. try: - data = await self._stream.receive_some(RECEIVE_BYTES) + data = await self._stream.receive_some(self._receive_buffer_size) except (trio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() break @@ -1619,6 +1660,7 @@ def __init__( handler_nursery: trio.Nursery | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT, ) -> None: @@ -1637,6 +1679,9 @@ def __init__( :param handler_nursery: An optional nursery to spawn connection tasks inside of. If ``None``, then a new nursery will be created internally. + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for a client to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client @@ -1649,6 +1694,7 @@ def __init__( self._listeners = listeners self._message_queue_size = message_queue_size self._max_message_size = max_message_size + self._receive_buffer_size = receive_buffer_size self._connect_timeout = connect_timeout self._disconnect_timeout = disconnect_timeout @@ -1741,7 +1787,8 @@ async def _handle_connection(self, stream: trio.abc.Stream) -> None: connection = WebSocketConnection(stream, WSConnection(ConnectionType.SERVER), message_queue_size=self._message_queue_size, - max_message_size=self._max_message_size) + max_message_size=self._max_message_size, + receive_buffer_size=self._receive_buffer_size) nursery.start_soon(connection._reader_task) with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request()