From 4bb02a40f82e244fd2a14b7ad8f8ff4be2136e2d Mon Sep 17 00:00:00 2001 From: palkeo Date: Thu, 16 Nov 2023 17:19:46 +0000 Subject: [PATCH 1/4] Add the ability to specify the buffer size. --- trio_websocket/_impl.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 259f4fd..9e7ce65 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -38,7 +38,7 @@ CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds MESSAGE_QUEUE_SIZE = 1 MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB +DEFAULT_RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB logger = logging.getLogger('trio-websocket') @@ -687,7 +687,8 @@ class WebSocketConnection(trio.abc.AsyncResource): def __init__(self, stream, ws_connection, *, host=None, path=None, client_subprotocols=None, client_extra_headers=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE): + max_message_size=MAX_MESSAGE_SIZE, + receive_buffer_size=DEFAULT_RECEIVE_BYTES): ''' Constructor. @@ -713,6 +714,9 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. ''' # NOTE: The implementation uses _close_reason for more than an advisory # purpose. It's critical internal state, indicating when the @@ -725,6 +729,7 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, self._message_size = 0 self._message_parts: List[Union[bytes, str]] = [] self._max_message_size = max_message_size + self._receive_buffer_size = receive_buffer_size self._reader_running = True if ws_connection.client: self._initial_request: Optional[Request] = Request(host=host, target=path, @@ -1232,7 +1237,7 @@ async def _reader_task(self): # Get network data. try: - data = await self._stream.receive_some(RECEIVE_BYTES) + data = await self._stream.receive_some(self._receive_buffer_size) except (trio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() break From e988b563173955993e08c55722e34d40bb96d532 Mon Sep 17 00:00:00 2001 From: palkeo Date: Thu, 16 Nov 2023 17:25:25 +0000 Subject: [PATCH 2/4] add type annotation --- trio_websocket/_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 9e7ce65..4bfe153 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -729,7 +729,7 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, self._message_size = 0 self._message_parts: List[Union[bytes, str]] = [] self._max_message_size = max_message_size - self._receive_buffer_size = receive_buffer_size + self._receive_buffer_size: Optional[int] = receive_buffer_size self._reader_running = True if ws_connection.client: self._initial_request: Optional[Request] = Request(host=host, target=path, From a6d4816760568405c2747efff4be44c53163560a Mon Sep 17 00:00:00 2001 From: palkeo Date: Thu, 23 Nov 2023 21:22:43 +0000 Subject: [PATCH 3/4] properly add receive_buffer_size everywhere --- tests/test_connection.py | 1 + trio_websocket/_impl.py | 72 +++++++++++++++++++++++++++++++--------- 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index f4e087a..736796a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -311,6 +311,7 @@ def test_client_open_url_options(open_websocket_mock): 'extra_headers': [(b'X-Test-Header', b'My test header')], 'message_queue_size': 9, 'max_message_size': 333, + 'receive_buffer_size': 999, 'connect_timeout': 36, 'disconnect_timeout': 37, } diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 4bfe153..66f97bb 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -38,7 +38,7 @@ CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds MESSAGE_QUEUE_SIZE = 1 MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -DEFAULT_RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB +RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB logger = logging.getLogger('trio-websocket') @@ -81,6 +81,7 @@ def __exit__(self, ty, value, tb): async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, extra_headers=None, message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, + receive_buffer_size=RECEIVE_BYTES, connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a host. @@ -106,6 +107,9 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for the connection before timing out. :param float disconnect_timeout: The number of seconds to wait when closing @@ -121,7 +125,8 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, resource, use_ssl=use_ssl, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) except trio.TooSlowError: raise ConnectionTimeout from None except OSError as e: @@ -138,7 +143,8 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, async def connect_websocket(nursery, host, port, resource, *, use_ssl, subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): + message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, + receive_buffer_size=RECEIVE_BYTES): ''' Return an open WebSocket client connection to a host. @@ -166,6 +172,9 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :rtype: WebSocketConnection ''' if use_ssl is True: @@ -194,7 +203,8 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, path=resource, client_subprotocols=subprotocols, client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection @@ -203,6 +213,7 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, def open_websocket_url(url, ssl_context=None, *, subprotocols=None, extra_headers=None, message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, + receive_buffer_size=RECEIVE_BYTES, connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Open a WebSocket client connection to a URL. @@ -226,6 +237,9 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for the connection before timing out. :param float disconnect_timeout: The number of seconds to wait when closing @@ -239,12 +253,14 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size, connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) async def connect_websocket_url(nursery, url, ssl_context=None, *, subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): + message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, + receive_buffer_size=RECEIVE_BYTES): ''' Return an open WebSocket client connection to a URL. @@ -269,13 +285,17 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :rtype: WebSocketConnection ''' host, port, resource, ssl_context = _url_to_host(url, ssl_context) return await connect_websocket(nursery, host, port, resource, use_ssl=ssl_context, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) def _url_to_host(url, ssl_context): @@ -316,7 +336,8 @@ def _url_to_host(url, ssl_context): async def wrap_client_stream(nursery, stream, host, resource, *, subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): + message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, + receive_buffer_size=RECEIVE_BYTES): ''' Wrap an arbitrary stream in a WebSocket connection. @@ -340,6 +361,9 @@ async def wrap_client_stream(nursery, stream, host, resource, *, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :rtype: WebSocketConnection ''' connection = WebSocketConnection(stream, @@ -347,14 +371,16 @@ async def wrap_client_stream(nursery, stream, host, resource, *, host=host, path=resource, client_subprotocols=subprotocols, client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection async def wrap_server_stream(nursery, stream, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): + message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, + receive_buffer_size=RECEIVE_BYTES): ''' Wrap an arbitrary stream in a server-side WebSocket. @@ -368,13 +394,17 @@ async def wrap_server_stream(nursery, stream, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :type stream: trio.abc.Stream :rtype: WebSocketRequest ''' connection = WebSocketConnection(stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size) nursery.start_soon(connection._reader_task) request = await connection._get_request() return request @@ -382,7 +412,8 @@ async def wrap_server_stream(nursery, stream, async def serve_websocket(handler, host, port, ssl_context, *, handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, + max_message_size=MAX_MESSAGE_SIZE, receive_buffer_size=RECEIVE_BYTES, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): ''' Serve a WebSocket over TCP. @@ -412,6 +443,9 @@ async def serve_websocket(handler, host, port, ssl_context, *, :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for a client to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client @@ -427,7 +461,9 @@ async def serve_websocket(handler, host, port, ssl_context, *, listeners = await open_tcp_listeners() server = WebSocketServer(handler, listeners, handler_nursery=handler_nursery, message_queue_size=message_queue_size, - max_message_size=max_message_size, connect_timeout=connect_timeout, + max_message_size=max_message_size, + receive_buffer_size=receive_buffer_size, + connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) await server.run(task_status=task_status) @@ -688,7 +724,7 @@ def __init__(self, stream, ws_connection, *, host=None, path=None, client_subprotocols=None, client_extra_headers=None, message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - receive_buffer_size=DEFAULT_RECEIVE_BYTES): + receive_buffer_size=RECEIVE_BYTES): ''' Constructor. @@ -1321,7 +1357,8 @@ class WebSocketServer: def __init__(self, handler, listeners, *, handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, + max_message_size=MAX_MESSAGE_SIZE, receive_buffer_size=RECEIVE_BYTES, + connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): ''' Constructor. @@ -1338,6 +1375,9 @@ def __init__(self, handler, listeners, *, handler_nursery=None, :param handler_nursery: An optional nursery to spawn connection tasks inside of. If ``None``, then a new nursery will be created internally. + :param Optional[int] receive_buffer_size: The buffer size we use to + receive messages internally. None to let trio choose. Defaults + to 4 KiB. :param float connect_timeout: The number of seconds to wait for a client to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client @@ -1350,6 +1390,7 @@ def __init__(self, handler, listeners, *, handler_nursery=None, self._listeners = listeners self._message_queue_size = message_queue_size self._max_message_size = max_message_size + self._receive_buffer_size = receive_buffer_size self._connect_timeout = connect_timeout self._disconnect_timeout = disconnect_timeout @@ -1432,7 +1473,8 @@ async def _handle_connection(self, stream): connection = WebSocketConnection(stream, WSConnection(ConnectionType.SERVER), message_queue_size=self._message_queue_size, - max_message_size=self._max_message_size) + max_message_size=self._max_message_size, + receive_buffer_size=self._receive_buffer_size) nursery.start_soon(connection._reader_task) with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() From b4c4fe7926c141b1b90b2bdf984740ea159c0194 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Fri, 31 Jan 2025 09:54:29 +0900 Subject: [PATCH 4/4] Fix type hints --- trio_websocket/_impl.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index b0b692b..dbd28c3 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -117,7 +117,7 @@ async def open_websocket( extra_headers: Optional[list[tuple[bytes,bytes]]] = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT ) -> AsyncGenerator[WebSocketConnection, None]: @@ -316,7 +316,7 @@ async def connect_websocket( extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketConnection: ''' Return an open WebSocket client connection to a host. @@ -394,7 +394,7 @@ def open_websocket_url( max_message_size: int = MAX_MESSAGE_SIZE, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> AbstractAsyncContextManager[WebSocketConnection]: ''' Open a WebSocket client connection to a URL. @@ -447,7 +447,7 @@ async def connect_websocket_url( extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketConnection: ''' Return an open WebSocket client connection to a URL. @@ -540,7 +540,7 @@ async def wrap_client_stream( extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketConnection: ''' Wrap an arbitrary stream in a WebSocket connection. @@ -587,7 +587,7 @@ async def wrap_server_stream( stream: trio.abc.Stream, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> WebSocketRequest: ''' Wrap an arbitrary stream in a server-side WebSocket. @@ -629,7 +629,7 @@ async def serve_websocket( handler_nursery: trio.Nursery | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT, task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, @@ -994,7 +994,7 @@ def __init__( client_extra_headers: list[tuple[bytes, bytes]] | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size=RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, ) -> None: ''' Constructor. @@ -1660,7 +1660,7 @@ def __init__( handler_nursery: trio.Nursery | None = None, message_queue_size: int = MESSAGE_QUEUE_SIZE, max_message_size: int = MAX_MESSAGE_SIZE, - receive_buffer_size: int = RECEIVE_BYTES, + receive_buffer_size: Union[None, int] = RECEIVE_BYTES, connect_timeout: float = CONN_TIMEOUT, disconnect_timeout: float = CONN_TIMEOUT, ) -> None: