diff --git a/Makefile b/Makefile index f97c321..67d1bee 100644 --- a/Makefile +++ b/Makefile @@ -11,10 +11,11 @@ test: $(PYTHON) -m pytest --cov=trio_websocket --cov-report=term-missing --no-cov-on-fail lint: + $(PYTHON) -m black trio_websocket/ tests/ autobahn/ examples/ $(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/ typecheck: - $(PYTHON) -m mypy --explicit-package-bases trio_websocket tests autobahn examples + $(PYTHON) -m mypy publish: rm -fr build dist .egg trio_websocket.egg-info diff --git a/autobahn/client.py b/autobahn/client.py index d93be1c..027bb4e 100644 --- a/autobahn/client.py +++ b/autobahn/client.py @@ -1,7 +1,8 @@ -''' +""" This test client runs against the Autobahn test server. It is based on the test_client.py in wsproto. -''' +""" + import argparse import json import logging @@ -11,28 +12,28 @@ from trio_websocket import open_websocket_url, ConnectionClosed -AGENT = 'trio-websocket' +AGENT = "trio-websocket" MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('client') +logger = logging.getLogger("client") -async def get_case_count(url): - url = url + '/getCaseCount' +async def get_case_count(url: str) -> int: + url = url + "/getCaseCount" async with open_websocket_url(url) as conn: case_count = await conn.get_message() - logger.info('Case count=%s', case_count) + logger.info("Case count=%s", case_count) return int(case_count) -async def get_case_info(url, case): - url = f'{url}/getCaseInfo?case={case}' +async def get_case_info(url: str, case: str) -> object: + url = f"{url}/getCaseInfo?case={case}" async with open_websocket_url(url) as conn: return json.loads(await conn.get_message()) -async def run_case(url, case): - url = f'{url}/runCase?case={case}&agent={AGENT}' +async def run_case(url: str, case: str) -> None: + url = f"{url}/runCase?case={case}&agent={AGENT}" try: async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn: while True: @@ -42,16 +43,16 @@ async def run_case(url, case): pass -async def update_reports(url): - url = f'{url}/updateReports?agent={AGENT}' +async def update_reports(url: str) -> None: + url = f"{url}/updateReports?agent={AGENT}" async with open_websocket_url(url) as conn: # This command runs as soon as we connect to it, so we don't need to # send any messages. pass -async def run_tests(args): - logger = logging.getLogger('trio-websocket') +async def run_tests(args: argparse.Namespace) -> None: + logger = logging.getLogger("trio-websocket") if args.debug_cases: # Don't fetch case count when debugging a subset of test cases. It adds # noise to the debug logging. @@ -62,7 +63,10 @@ async def run_tests(args): test_cases = list(range(1, case_count + 1)) exception_cases = [] for case in test_cases: - case_id = (await get_case_info(args.url, case))['id'] + result = await get_case_info(args.url, case) + assert isinstance(result, dict) + case_id = result["id"] + assert isinstance(case_id, int) if case_count: logger.info("Running test case %s (%d of %d)", case_id, case, case_count) else: @@ -71,28 +75,37 @@ async def run_tests(args): try: await run_case(args.url, case) except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception during test case %s (%d)', case_id, case) + logger.exception( + " runtime exception during test case %s (%d)", case_id, case + ) exception_cases.append(case_id) logger.setLevel(logging.INFO) - logger.info('Updating report') + logger.info("Updating report") await update_reports(args.url) if exception_cases: - logger.error('Runtime exception in %d of %d test cases: %s', - len(exception_cases), len(test_cases), exception_cases) + logger.error( + "Runtime exception in %d of %d test cases: %s", + len(exception_cases), + len(test_cases), + exception_cases, + ) sys.exit(1) -def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn client for' - ' trio-websocket') - parser.add_argument('url', help='WebSocket URL for server') +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Autobahn client for trio-websocket") + parser.add_argument("url", help="WebSocket URL for server") # TODO: accept case ID's rather than indices - parser.add_argument('debug_cases', type=int, nargs='*', help='Run' - ' individual test cases with debug logging (optional)') + parser.add_argument( + "debug_cases", + type=int, + nargs="*", + help="Run individual test cases with debug logging (optional)", + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() trio.run(run_tests, args) diff --git a/autobahn/server.py b/autobahn/server.py index ff23846..223248d 100644 --- a/autobahn/server.py +++ b/autobahn/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,34 +7,36 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" + import argparse import logging import trio from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest -BIND_IP = '0.0.0.0' +BIND_IP = "0.0.0.0" BIND_PORT = 9000 MAX_MESSAGE_SIZE = 16 * 1024 * 1024 logging.basicConfig() -logger = logging.getLogger('client') +logger = logging.getLogger("client") logger.setLevel(logging.INFO) connection_count = 0 -async def main(): - ''' Main entry point. ''' - logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT) - await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None, - max_message_size=MAX_MESSAGE_SIZE) +async def main() -> None: + """Main entry point.""" + logger.info("Starting websocket server on ws://%s:%d", BIND_IP, BIND_PORT) + await serve_websocket( + handler, BIND_IP, BIND_PORT, ssl_context=None, max_message_size=MAX_MESSAGE_SIZE + ) -async def handler(request: WebSocketRequest): - ''' Reverse incoming websocket messages and send them back. ''' +async def handler(request: WebSocketRequest) -> None: + """Reverse incoming websocket messages and send them back.""" global connection_count # pylint: disable=global-statement connection_count += 1 - logger.info('Connection #%d', connection_count) + logger.info("Connection #%d", connection_count) ws = await request.accept() while True: try: @@ -43,20 +45,22 @@ async def handler(request: WebSocketRequest): except ConnectionClosed: break except Exception: # pylint: disable=broad-exception-caught - logger.exception(' runtime exception handling connection #%d', connection_count) + logger.exception( + " runtime exception handling connection #%d", connection_count + ) -def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Autobahn server for' - ' trio-websocket') - parser.add_argument('-d', '--debug', action='store_true', - help='WebSocket URL for server') +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Autobahn server for trio-websocket") + parser.add_argument( + "-d", "--debug", action="store_true", help="WebSocket URL for server" + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() if args.debug: - logging.getLogger('trio-websocket').setLevel(logging.DEBUG) + logging.getLogger("trio-websocket").setLevel(logging.DEBUG) trio.run(main) diff --git a/docs/conf.py b/docs/conf.py index 649051b..88a2596 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,11 +19,12 @@ # -- Project information ----------------------------------------------------- -project = 'Trio WebSocket' -copyright = '2018, Hyperion Gray' -author = 'Hyperion Gray' +project = "Trio WebSocket" +copyright = "2018, Hyperion Gray" +author = "Hyperion Gray" from trio_websocket._version import __version__ as version + release = version @@ -37,22 +38,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinxcontrib_trio', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinxcontrib_trio", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -64,7 +65,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -75,7 +76,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -86,7 +87,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -102,26 +103,22 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'TrioWebSocketdoc' +htmlhelp_basename = "TrioWebSocketdoc" # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). - # # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). # + # The font size ('10pt', '11pt' or '12pt'). # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. # + # Additional stuff for the LaTeX preamble. # 'preamble': '', - - # Latex figure (float) alignment # + # Latex figure (float) alignment # 'figure_align': 'htbp', } @@ -129,8 +126,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'TrioWebSocket.tex', 'Trio WebSocket Documentation', - 'Hyperion Gray', 'manual'), + ( + master_doc, + "TrioWebSocket.tex", + "Trio WebSocket Documentation", + "Hyperion Gray", + "manual", + ), ] @@ -138,10 +140,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'triowebsocket', 'Trio WebSocket Documentation', - [author], 1) -] +man_pages = [(master_doc, "triowebsocket", "Trio WebSocket Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -150,9 +149,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'TrioWebSocket', 'Trio WebSocket Documentation', - author, 'TrioWebSocket', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "TrioWebSocket", + "Trio WebSocket Documentation", + author, + "TrioWebSocket", + "One line description of project.", + "Miscellaneous", + ), ] @@ -171,10 +176,10 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- intersphinx_mapping = { - 'trio': ('https://trio.readthedocs.io/en/stable/', None), + "trio": ("https://trio.readthedocs.io/en/stable/", None), } diff --git a/examples/client.py b/examples/client.py index 030c12b..c145357 100644 --- a/examples/client.py +++ b/examples/client.py @@ -1,69 +1,80 @@ -''' +""" This interactive WebSocket client allows the user to send frames to a WebSocket server, including text message, ping, and close frames. To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" + import argparse import logging import pathlib import ssl import sys import urllib.parse +from typing import NoReturn import trio -from trio_websocket import open_websocket_url, ConnectionClosed, HandshakeError +from trio_websocket import ( + open_websocket_url, + ConnectionClosed, + HandshakeError, + WebSocketConnection, + CloseReason, +) logging.basicConfig(level=logging.DEBUG) here = pathlib.Path(__file__).parent -def commands(): - ''' Print the supported commands. ''' - print('Commands: ') - print('send -> send message') - print('ping -> send ping with payload') - print('close [] -> politely close connection with optional reason') +def commands() -> None: + """Print the supported commands.""" + print("Commands: ") + print("send -> send message") + print("ping -> send ping with payload") + print("close [] -> politely close connection with optional reason") print() -def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--heartbeat', action='store_true', - help='Create a heartbeat task') - parser.add_argument('url', help='WebSocket URL to connect to') +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument( + "--heartbeat", action="store_true", help="Create a heartbeat task" + ) + parser.add_argument("url", help="WebSocket URL to connect to") return parser.parse_args() -async def main(args): - ''' Main entry point, returning False in the case of logged error. ''' - if urllib.parse.urlsplit(args.url).scheme == 'wss': +async def main(args: argparse.Namespace) -> bool: + """Main entry point, returning False in the case of logged error.""" + if urllib.parse.urlsplit(args.url).scheme == "wss": # Configure SSL context to handle our self-signed certificate. Most # clients won't need to do this. try: ssl_context = ssl.create_default_context() - ssl_context.load_verify_locations(here / 'fake.ca.pem') + ssl_context.load_verify_locations(here / "fake.ca.pem") except FileNotFoundError: - logging.error('Did not find file "fake.ca.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.ca.pem". You need to run generate-cert.py' + ) return False else: ssl_context = None try: - logging.debug('Connecting to WebSocket…') + logging.debug("Connecting to WebSocket…") async with open_websocket_url(args.url, ssl_context) as conn: await handle_connection(conn, args.heartbeat) except HandshakeError as e: - logging.error('Connection attempt failed: %s', e) + logging.error("Connection attempt failed: %s", e) return False + return True -async def handle_connection(ws, use_heartbeat): - ''' Handle the connection. ''' - logging.debug('Connected!') +async def handle_connection(ws: WebSocketConnection, use_heartbeat: bool) -> None: + """Handle the connection.""" + logging.debug("Connected!") try: async with trio.open_nursery() as nursery: if use_heartbeat: @@ -71,12 +82,15 @@ async def handle_connection(ws, use_heartbeat): nursery.start_soon(get_commands, ws) nursery.start_soon(get_messages, ws) except ConnectionClosed as cc: - reason = '' if cc.reason.reason is None else f'"{cc.reason.reason}"' - print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}') + assert isinstance(cc.reason, CloseReason) + reason = "" if cc.reason.reason is None else f'"{cc.reason.reason}"' + print(f"Closed: {cc.reason.code}/{cc.reason.name} {reason}") -async def heartbeat(ws, timeout, interval): - ''' +async def heartbeat( + ws: WebSocketConnection, timeout: float, interval: float +) -> NoReturn: + """ Send periodic pings on WebSocket ``ws``. Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises @@ -92,28 +106,27 @@ async def heartbeat(ws, timeout, interval): :raises: ``ConnectionClosed`` if ``ws`` is closed. :raises: ``TooSlowError`` if the timeout expires. :returns: This function runs until cancelled. - ''' + """ while True: with trio.fail_after(timeout): await ws.ping() await trio.sleep(interval) -async def get_commands(ws): - ''' In a loop: get a command from the user and execute it. ''' +async def get_commands(ws: WebSocketConnection) -> None: + """In a loop: get a command from the user and execute it.""" while True: - cmd = await trio.to_thread.run_sync(input, 'cmd> ', - cancellable=True) - if cmd.startswith('ping'): - payload = cmd[5:].encode('utf8') or None + cmd = await trio.to_thread.run_sync(input, "cmd> ") + if cmd.startswith("ping"): + payload = cmd[5:].encode("utf8") or None await ws.ping(payload) - elif cmd.startswith('send'): + elif cmd.startswith("send"): message = cmd[5:] or None if message is None: logging.error('The "send" command requires a message.') else: await ws.send_message(message) - elif cmd.startswith('close'): + elif cmd.startswith("close"): reason = cmd[6:] or None await ws.aclose(code=1000, reason=reason) break @@ -123,14 +136,14 @@ async def get_commands(ws): await trio.sleep(0.25) -async def get_messages(ws): - ''' In a loop: get a WebSocket message and print it out. ''' +async def get_messages(ws: WebSocketConnection) -> None: + """In a loop: get a WebSocket message and print it out.""" while True: message = await ws.get_message() - print(f'message: {message}') + print(f"message: {message!r}") -if __name__ == '__main__': +if __name__ == "__main__": try: if not trio.run(main, parse_args()): sys.exit(1) diff --git a/examples/generate-cert.py b/examples/generate-cert.py index cc21698..2f5429b 100644 --- a/examples/generate-cert.py +++ b/examples/generate-cert.py @@ -3,22 +3,23 @@ import trustme -def main(): + +def main() -> None: here = pathlib.Path(__file__).parent - ca_path = here / 'fake.ca.pem' - server_path = here / 'fake.server.pem' + ca_path = here / "fake.ca.pem" + server_path = here / "fake.server.pem" if ca_path.exists() and server_path.exists(): - print('The CA ceritificate and server certificate already exist.') + print("The CA ceritificate and server certificate already exist.") sys.exit(1) - print('Creating self-signed certificate for localhost/127.0.0.1:') + print("Creating self-signed certificate for localhost/127.0.0.1:") ca_cert = trustme.CA() ca_cert.cert_pem.write_to_path(ca_path) - print(f' * CA certificate: {ca_path}') - server_cert = ca_cert.issue_server_cert('localhost', '127.0.0.1') + print(f" * CA certificate: {ca_path}") + server_cert = ca_cert.issue_server_cert("localhost", "127.0.0.1") server_cert.private_key_and_cert_chain_pem.write_to_path(server_path) - print(f' * Server certificate: {server_path}') - print('Done') + print(f" * Server certificate: {server_path}") + print("Done") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/server.py b/examples/server.py index 611d89b..1af40f4 100644 --- a/examples/server.py +++ b/examples/server.py @@ -1,4 +1,4 @@ -''' +""" This simple WebSocket server responds to text messages by reversing each message string and sending it back. @@ -7,14 +7,15 @@ To use SSL/TLS: install the `trustme` package from PyPI and run the `generate-cert.py` script in this directory. -''' +""" + import argparse import logging import pathlib import ssl import trio -from trio_websocket import serve_websocket, ConnectionClosed +from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest logging.basicConfig(level=logging.DEBUG) @@ -22,34 +23,38 @@ here = pathlib.Path(__file__).parent -def parse_args(): - ''' Parse command line arguments. ''' - parser = argparse.ArgumentParser(description='Example trio-websocket client') - parser.add_argument('--ssl', action='store_true', help='Use SSL') - parser.add_argument('host', help='Host interface to bind. If omitted, ' - 'then bind all interfaces.', nargs='?') - parser.add_argument('port', type=int, help='Port to bind.') +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Example trio-websocket client") + parser.add_argument("--ssl", action="store_true", help="Use SSL") + parser.add_argument( + "host", + help="Host interface to bind. If omitted, then bind all interfaces.", + nargs="?", + ) + parser.add_argument("port", type=int, help="Port to bind.") return parser.parse_args() -async def main(args): - ''' Main entry point. ''' - logging.info('Starting websocket server…') +async def main(args: argparse.Namespace) -> None: + """Main entry point.""" + logging.info("Starting websocket server…") if args.ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) try: - ssl_context.load_cert_chain(here / 'fake.server.pem') + ssl_context.load_cert_chain(here / "fake.server.pem") except FileNotFoundError: - logging.error('Did not find file "fake.server.pem". You need to run' - ' generate-cert.py') + logging.error( + 'Did not find file "fake.server.pem". You need to run generate-cert.py' + ) else: ssl_context = None - host = None if args.host == '*' else args.host + host = None if args.host == "*" else args.host await serve_websocket(handler, host, args.port, ssl_context) -async def handler(request): - ''' Reverse incoming websocket messages and send them back. ''' +async def handler(request: WebSocketRequest) -> None: + """Reverse incoming websocket messages and send them back.""" logging.info('Handler starting on path "%s"', request.path) ws = await request.accept() while True: @@ -57,12 +62,12 @@ async def handler(request): message = await ws.get_message() await ws.send_message(message[::-1]) except ConnectionClosed: - logging.info('Connection closed') + logging.info("Connection closed") break - logging.info('Handler exiting') + logging.info("Handler exiting") -if __name__ == '__main__': +if __name__ == "__main__": try: trio.run(main, parse_args()) except KeyboardInterrupt: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..95d5ff9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.mypy] +explicit_package_bases = true +files = ["trio_websocket", "tests", "autobahn", "examples"] +show_column_numbers = true +show_error_codes = true +show_traceback = true +disallow_any_decorated = true +disallow_any_unimported = true +ignore_missing_imports = true +local_partial_types = true +no_implicit_optional = true +strict = true +warn_unreachable = true diff --git a/requirements-dev-full.txt b/requirements-dev-full.txt index dbc8570..c554717 100644 --- a/requirements-dev-full.txt +++ b/requirements-dev-full.txt @@ -15,6 +15,10 @@ attrs==23.2.0 # trio babel==2.15.0 # via sphinx +black==24.4.2 + # via -r requirements-dev.in +bleach==6.0.0 + # via readme-renderer backports-tarfile==1.2.0 # via jaraco-context build==1.2.1 diff --git a/requirements-dev.in b/requirements-dev.in index 922fb76..30907fd 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,5 +1,6 @@ # requirements for `make test` and dependency management attrs>=19.2.0 +black>=24.4.2 pip-tools>=5.5.0 pytest>=4.6 pytest-cov diff --git a/requirements-extras.in b/requirements-extras.in index eb2cd30..10fe1ef 100644 --- a/requirements-extras.in +++ b/requirements-extras.in @@ -1,4 +1,5 @@ # requirements for `make lint/docs/publish` +black mypy pylint sphinx diff --git a/setup.py b/setup.py index 17a21f9..2166c5e 100644 --- a/setup.py +++ b/setup.py @@ -10,42 +10,44 @@ # Get description -with (here / 'README.md').open(encoding='utf-8') as f: +with (here / "README.md").open(encoding="utf-8") as f: long_description = f.read() setup( - name='trio-websocket', - version=version['__version__'], - description='WebSocket library for Trio', + name="trio-websocket", + version=version["__version__"], + description="WebSocket library for Trio", long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/python-trio/trio-websocket', - author='Mark E. Haase', - author_email='mehaase@gmail.com', + long_description_content_type="text/markdown", + url="https://github.com/python-trio/trio-websocket", + author="Mark E. Haase", + author_email="mehaase@gmail.com", classifiers=[ # See https://pypi.org/classifiers/ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Typing :: Typed", ], python_requires=">=3.8", - keywords='websocket client server trio', - packages=find_packages(exclude=['docs', 'examples', 'tests']), + keywords="websocket client server trio", + packages=find_packages(exclude=["docs", "examples", "tests"]), + package_data={"trio-websocket": ["py.typed"]}, install_requires=[ 'exceptiongroup; python_version<"3.11"', - 'trio>=0.11', - 'wsproto>=0.14', + "trio>=0.11", + "wsproto>=0.14", ], project_urls={ - 'Bug Reports': 'https://github.com/python-trio/trio-websocket/issues', - 'Source': 'https://github.com/python-trio/trio-websocket', + "Bug Reports": "https://github.com/python-trio/trio-websocket/issues", + "Source": "https://github.com/python-trio/trio-websocket", }, ) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6cccefa..0583523 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,4 @@ -''' +""" Unit tests for trio_websocket. Many of these tests involve networking, i.e. real TCP sockets. To maximize @@ -28,14 +28,19 @@ call ``ws.get_message()`` without actually sending it a message. This will cause the server to block until the client has sent the closing handshake. In other circumstances -''' +""" + from __future__ import annotations -from functools import partial, wraps +import copy import re import ssl import sys -from unittest.mock import patch +from collections.abc import AsyncGenerator +from functools import partial, wraps +from typing import TYPE_CHECKING, TypeVar, cast +from unittest.mock import Mock, patch +from importlib.metadata import version import attr import pytest @@ -57,34 +62,48 @@ except ImportError: pass + from trio_websocket import ( - connect_websocket, - connect_websocket_url, + CloseReason, ConnectionClosed, ConnectionRejected, ConnectionTimeout, DisconnectionTimeout, Endpoint, HandshakeError, + WebSocketConnection, + WebSocketRequest, + WebSocketServer, + connect_websocket, + connect_websocket_url, open_websocket, open_websocket_url, serve_websocket, - WebSocketConnection, - WebSocketServer, - WebSocketRequest, wrap_client_stream, - wrap_server_stream + wrap_server_stream, ) - from trio_websocket._impl import _TRIO_EXC_GROUP_TYPE if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + from wsproto.events import Event + + from typing_extensions import ParamSpec, TypeAlias + + PS = ParamSpec("PS") -HOST = '127.0.0.1' -RESOURCE = '/resource' + StapledMemoryStream: TypeAlias = trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ] + +WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split("."))) + +HOST = "127.0.0.1" +RESOURCE = "/resource" DEFAULT_TEST_MAX_DURATION = 1 # Timeout tests follow a general pattern: one side waits TIMEOUT seconds for an @@ -95,30 +114,34 @@ FORCE_TIMEOUT = 2 TIMEOUT_TEST_MAX_DURATION = 3 +T = TypeVar("T") + @pytest.fixture -async def echo_server(nursery): - ''' A server that reads one message, sends back the same message, - then closes the connection. ''' - serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, - ssl_context=None) +async def echo_server(nursery: trio.Nursery) -> AsyncGenerator[WebSocketServer, None]: + """A server that reads one message, sends back the same message, + then closes the connection.""" + serve_fn = partial(serve_websocket, echo_request_handler, HOST, 0, ssl_context=None) server = await nursery.start(serve_fn) - yield server + # Cast needed because currently `nursery.start` has typing issues + # blocked by https://github.com/python/mypy/pull/17512 + yield cast(WebSocketServer, server) @pytest.fixture -async def echo_conn(echo_server): - ''' Return a client connection instance that is connected to an echo - server. ''' - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as conn: +async def echo_conn( + echo_server: WebSocketServer, +) -> AsyncGenerator[WebSocketConnection, None]: + """Return a client connection instance that is connected to an echo + server.""" + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: yield conn -async def echo_request_handler(request): - ''' +async def echo_request_handler(request: WebSocketRequest) -> None: + """ Accept incoming request and then pass off to echo connection handler. - ''' + """ conn = await request.accept() try: msg = await conn.get_message() @@ -128,37 +151,59 @@ async def echo_request_handler(request): class fail_after: - ''' This decorator fails if the runtime of the decorated function (as - measured by the Trio clock) exceeds the specified value. ''' - def __init__(self, seconds): + """This decorator fails if the runtime of the decorated function (as + measured by the Trio clock) exceeds the specified value.""" + + def __init__(self, seconds: int) -> None: self._seconds = seconds - def __call__(self, fn): + def __call__( + self, fn: Callable[PS, Awaitable[T]] + ) -> Callable[PS, Awaitable[T | None]]: + # Type of decorated function contains type `Any` @wraps(fn) - async def wrapper(*args, **kwargs): + async def wrapper( # type: ignore[misc] + *args: PS.args, + **kwargs: PS.kwargs, + ) -> T: with trio.move_on_after(self._seconds) as cancel_scope: - await fn(*args, **kwargs) + return await fn(*args, **kwargs) if cancel_scope.cancelled_caught: - pytest.fail(f'Test runtime exceeded the maximum {self._seconds} seconds') + pytest.fail( + f"Test runtime exceeded the maximum {self._seconds} seconds" + ) + raise AssertionError("Should be unreachable") + return wrapper @attr.s(hash=False, eq=False) -class MemoryListener(trio.abc.Listener): - closed = attr.ib(default=False) - accepted_streams: list[ - tuple[trio.abc.SendChannel[str], trio.abc.ReceiveChannel[str]] - ] = attr.ib(factory=list) - queued_streams = attr.ib(factory=lambda: trio.open_memory_channel[str](1)) - accept_hook = attr.ib(default=None) - - async def connect(self): +class MemoryListener(trio.abc.Listener["StapledMemoryStream"]): + closed: bool = attr.ib(default=False) + accepted_streams: list[StapledMemoryStream] = attr.ib(factory=list) + queued_streams: tuple[ + trio.MemorySendChannel[StapledMemoryStream], + trio.MemoryReceiveChannel[StapledMemoryStream], + ] = attr.ib(factory=lambda: trio.open_memory_channel["StapledMemoryStream"](1)) + accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) + + async def connect( + self, + ) -> trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ]: assert not self.closed client, server = memory_stream_pair() await self.queued_streams[0].send(server) return client - async def accept(self): + async def accept( + self, + ) -> trio.StapledStream[ + trio.testing.MemorySendStream, + trio.testing.MemoryReceiveStream, + ]: await trio.sleep(0) assert not self.closed if self.accept_hook is not None: @@ -167,47 +212,49 @@ async def accept(self): self.accepted_streams.append(stream) return stream - async def aclose(self): + async def aclose(self) -> None: self.closed = True await trio.sleep(0) -async def test_endpoint_ipv4(): - e1 = Endpoint('10.105.0.2', 80, False) - assert e1.url == 'ws://10.105.0.2' +async def test_endpoint_ipv4() -> None: + e1 = Endpoint("10.105.0.2", 80, False) + assert e1.url == "ws://10.105.0.2" assert str(e1) == 'Endpoint(address="10.105.0.2", port=80, is_ssl=False)' - e2 = Endpoint('127.0.0.1', 8000, False) - assert e2.url == 'ws://127.0.0.1:8000' + e2 = Endpoint("127.0.0.1", 8000, False) + assert e2.url == "ws://127.0.0.1:8000" assert str(e2) == 'Endpoint(address="127.0.0.1", port=8000, is_ssl=False)' - e3 = Endpoint('0.0.0.0', 443, True) - assert e3.url == 'wss://0.0.0.0' + e3 = Endpoint("0.0.0.0", 443, True) + assert e3.url == "wss://0.0.0.0" assert str(e3) == 'Endpoint(address="0.0.0.0", port=443, is_ssl=True)' -async def test_listen_port_ipv6(): - e1 = Endpoint('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False) - assert e1.url == 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]' - assert str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' \ - ':51ab", port=80, is_ssl=False)' - e2 = Endpoint('::1', 8000, False) - assert e2.url == 'ws://[::1]:8000' +async def test_listen_port_ipv6() -> None: + e1 = Endpoint("2599:8807:6201:b7:16cf:bb9c:a6d3:51ab", 80, False) + assert e1.url == "ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]" + assert ( + str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' + ':51ab", port=80, is_ssl=False)' + ) + e2 = Endpoint("::1", 8000, False) + assert e2.url == "ws://[::1]:8000" assert str(e2) == 'Endpoint(address="::1", port=8000, is_ssl=False)' - e3 = Endpoint('::', 443, True) - assert e3.url == 'wss://[::]' + e3 = Endpoint("::", 443, True) + assert e3.url == "wss://[::]" assert str(e3) == 'Endpoint(address="::", port=443, is_ssl=True)' -async def test_server_has_listeners(nursery): - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) +async def test_server_has_listeners(nursery: trio.Nursery) -> None: + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) assert len(server.listeners) > 0 assert isinstance(server.listeners[0], Endpoint) -async def test_serve(nursery): +async def test_serve(nursery: trio.Nursery) -> None: task = current_task() - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - None) + server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) port = server.port assert server.port != 0 # The server nursery begins with one task (server.listen). @@ -220,7 +267,7 @@ async def test_serve(nursery): assert len(task.child_nurseries) == no_clients_nursery_count + 1 -async def test_serve_ssl(nursery): +async def test_serve_ssl(nursery: trio.Nursery) -> None: server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) client_context = ssl.create_default_context() ca = trustme.CA() @@ -228,21 +275,31 @@ async def test_serve_ssl(nursery): cert = ca.issue_server_cert(HOST) cert.configure_cert(server_context) - server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0, - server_context) + server = await nursery.start( + serve_websocket, echo_request_handler, HOST, 0, server_context + ) + assert isinstance(server, WebSocketServer) port = server.port - async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context - ) as conn: + async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context) as conn: assert not conn.closed + assert isinstance(conn.local, Endpoint) assert conn.local.is_ssl + assert isinstance(conn.remote, Endpoint) assert conn.remote.is_ssl -async def test_serve_handler_nursery(nursery): +async def test_serve_handler_nursery(nursery: trio.Nursery) -> None: async with trio.open_nursery() as handler_nursery: - serve_with_nursery = partial(serve_websocket, echo_request_handler, - HOST, 0, None, handler_nursery=handler_nursery) + serve_with_nursery = partial( + serve_websocket, + echo_request_handler, + HOST, + 0, + None, + handler_nursery=handler_nursery, + ) server = await nursery.start(serve_with_nursery) + assert isinstance(server, WebSocketServer) port = server.port # The server nursery begins with one task (server.listen). assert len(nursery.child_tasks) == 1 @@ -252,25 +309,40 @@ async def test_serve_handler_nursery(nursery): assert len(handler_nursery.child_tasks) == 1 -async def test_serve_with_zero_listeners(): +async def test_serve_with_zero_listeners() -> None: with pytest.raises(ValueError): WebSocketServer(echo_request_handler, []) -async def test_serve_non_tcp_listener(nursery): - listeners = [MemoryListener()] - server = WebSocketServer(echo_request_handler, listeners) +def memory_listener() -> trio.SocketListener: + return MemoryListener() # type: ignore[return-value] + + +async def test_serve_non_tcp_listener(nursery: trio.Nursery) -> None: + listeners = [memory_listener()] + server = WebSocketServer( + echo_request_handler, + listeners, + ) await nursery.start(server.run) assert len(server.listeners) == 1 with pytest.raises(RuntimeError): server.port # pylint: disable=pointless-statement - assert server.listeners[0].startswith('MemoryListener(') + listener = server.listeners[0] + assert isinstance(listener, str) + assert listener.startswith("MemoryListener(") -async def test_serve_multiple_listeners(nursery): +async def test_serve_multiple_listeners(nursery: trio.Nursery) -> None: listener1 = (await trio.open_tcp_listeners(0, host=HOST))[0] - listener2 = MemoryListener() - server = WebSocketServer(echo_request_handler, [listener1, listener2]) + listener2 = memory_listener() + server = WebSocketServer( + echo_request_handler, + [ + listener1, + listener2, + ], + ) await nursery.start(server.run) assert len(server.listeners) == 2 with pytest.raises(RuntimeError): @@ -278,90 +350,122 @@ async def test_serve_multiple_listeners(nursery): # usable if you have exactly one listener. server.port # pylint: disable=pointless-statement # The first listener metadata is a ListenPort instance. - assert server.listeners[0].port != 0 + listener_zero = server.listeners[0] + assert isinstance(listener_zero, Endpoint) + assert listener_zero.port != 0 # The second listener metadata is a string containing the repr() of a # MemoryListener object. - assert server.listeners[1].startswith('MemoryListener(') + listener_one = server.listeners[1] + assert isinstance(listener_one, str) + assert listener_one.startswith("MemoryListener(") -async def test_client_open(echo_server): - async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) \ - as conn: +async def test_client_open(echo_server: WebSocketServer) -> None: + async with open_websocket(HOST, echo_server.port, RESOURCE, use_ssl=False) as conn: assert not conn.closed assert conn.is_client - assert str(conn).startswith('client-') + assert str(conn).startswith("client-") -@pytest.mark.parametrize('path, expected_path', [ - ('/', '/'), - ('', '/'), - (RESOURCE + '/path', RESOURCE + '/path'), - (RESOURCE + '?foo=bar', RESOURCE + '?foo=bar') -]) -async def test_client_open_url(path, expected_path, echo_server): - url = f'ws://{HOST}:{echo_server.port}{path}' +@pytest.mark.parametrize( + "path, expected_path", + [ + ("/", "/"), + ("", "/"), + (RESOURCE + "/path", RESOURCE + "/path"), + (RESOURCE + "?foo=bar", RESOURCE + "?foo=bar"), + ], +) +async def test_client_open_url( + path: str, expected_path: str, echo_server: WebSocketServer +) -> None: + url = f"ws://{HOST}:{echo_server.port}{path}" async with open_websocket_url(url) as conn: assert conn.path == expected_path -async def test_client_open_invalid_url(echo_server): +async def test_client_open_invalid_url(echo_server: WebSocketServer) -> None: with pytest.raises(ValueError): - async with open_websocket_url('http://foo.com/bar'): + async with open_websocket_url("http://foo.com/bar"): pass -async def test_client_open_invalid_ssl(echo_server, nursery): - with pytest.raises(TypeError, match='`use_ssl` argument must be bool or ssl.SSLContext'): - await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, use_ssl=1) - - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' - with pytest.raises(ValueError, match='^SSL context must be None for ws: URL scheme$' ): - await connect_websocket_url(nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) - -async def test_ascii_encoded_path_is_ok(echo_server): - path = '%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90' - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}/{path}' +async def test_client_open_invalid_ssl( + echo_server: WebSocketServer, + nursery: trio.Nursery, +) -> None: + with pytest.raises( + TypeError, match="`use_ssl` argument must be bool or ssl.SSLContext" + ): + await connect_websocket( + nursery, + HOST, + echo_server.port, + RESOURCE, + use_ssl=1, # type: ignore[arg-type] + ) + + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" + with pytest.raises( + ValueError, match="^SSL context must be None for ws: URL scheme$" + ): + await connect_websocket_url( + nursery, url, ssl_context=ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ) + + +async def test_ascii_encoded_path_is_ok(echo_server: WebSocketServer) -> None: + path = "%D7%90%D7%91%D7%90?%D7%90%D7%9E%D7%90" + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}/{path}" async with open_websocket_url(url) as conn: - assert conn.path == RESOURCE + '/' + path + assert conn.path == RESOURCE + "/" + path -@patch('trio_websocket._impl.open_websocket') -def test_client_open_url_options(open_websocket_mock): +# Type ignore because @patch contains `Any` +@patch("trio_websocket._impl.open_websocket") +def test_client_open_url_options( # type: ignore[misc] + open_websocket_mock: Mock, +) -> None: """open_websocket_url() must pass its options on to open_websocket()""" port = 1234 - url = f'ws://{HOST}:{port}{RESOURCE}' + url = f"ws://{HOST}:{port}{RESOURCE}" options = { - 'subprotocols': ['chat'], - 'extra_headers': [(b'X-Test-Header', b'My test header')], - 'message_queue_size': 9, - 'max_message_size': 333, - 'connect_timeout': 36, - 'disconnect_timeout': 37, + "subprotocols": ["chat"], + "extra_headers": [(b"X-Test-Header", b"My test header")], + "message_queue_size": 9, + "max_message_size": 333, + "connect_timeout": 36, + "disconnect_timeout": 37, } - open_websocket_url(url, **options) + open_websocket_url(url, **options) # type: ignore[arg-type] _, call_args, call_kwargs = open_websocket_mock.mock_calls[0] assert call_args == (HOST, port, RESOURCE) - assert not call_kwargs.pop('use_ssl') + assert not call_kwargs.pop("use_ssl") assert call_kwargs == options - open_websocket_url(url.replace('ws:', 'wss:')) + open_websocket_url(url.replace("ws:", "wss:")) _, call_args, call_kwargs = open_websocket_mock.mock_calls[1] - assert call_kwargs['use_ssl'] + assert call_kwargs["use_ssl"] -async def test_client_connect(echo_server, nursery): - conn = await connect_websocket(nursery, HOST, echo_server.port, RESOURCE, - use_ssl=False) +async def test_client_connect( + echo_server: WebSocketServer, nursery: trio.Nursery +) -> None: + conn = await connect_websocket( + nursery, HOST, echo_server.port, RESOURCE, use_ssl=False + ) assert not conn.closed -async def test_client_connect_url(echo_server, nursery): - url = f'ws://{HOST}:{echo_server.port}{RESOURCE}' +async def test_client_connect_url( + echo_server: WebSocketServer, nursery: trio.Nursery +) -> None: + url = f"ws://{HOST}:{echo_server.port}{RESOURCE}" conn = await connect_websocket_url(nursery, url) assert not conn.closed -async def test_connection_has_endpoints(echo_conn): +async def test_connection_has_endpoints(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert isinstance(echo_conn.local, Endpoint) assert str(echo_conn.local.address) == HOST @@ -375,91 +479,111 @@ async def test_connection_has_endpoints(echo_conn): @fail_after(1) -async def test_handshake_has_endpoints(nursery): - async def handler(request): +async def test_handshake_has_endpoints(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + assert isinstance(server, WebSocketServer) + assert isinstance(request.local, Endpoint) assert str(request.local.address) == HOST assert request.local.port == server.port assert not request.local.is_ssl + assert isinstance(request.remote, Endpoint) assert str(request.remote.address) == HOST assert not request.remote.is_ssl await request.accept() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass -async def test_handshake_subprotocol(nursery): - async def handler(request): - assert request.proposed_subprotocols == ('chat', 'file') - server_ws = await request.accept(subprotocol='chat') - assert server_ws.subprotocol == 'chat' +async def test_handshake_subprotocol(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + assert request.proposed_subprotocols == ("chat", "file") + server_ws = await request.accept(subprotocol="chat") + assert server_ws.subprotocol == "chat" server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - subprotocols=('chat', 'file')) as client_ws: - assert client_ws.subprotocol == 'chat' + assert isinstance(server, WebSocketServer) + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, subprotocols=("chat", "file") + ) as client_ws: + assert client_ws.subprotocol == "chat" -async def test_handshake_path(nursery): - async def handler(request): +async def test_handshake_path(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: assert request.path == RESOURCE server_ws = await request.accept() assert server_ws.path == RESOURCE server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - ) as client_ws: + assert isinstance(server, WebSocketServer) + async with open_websocket( + HOST, + server.port, + RESOURCE, + use_ssl=False, + ) as client_ws: assert client_ws.path == RESOURCE @fail_after(1) -async def test_handshake_client_headers(nursery): - async def handler(request): +async def test_handshake_client_headers(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: headers = dict(request.headers) - assert b'x-test-header' in headers - assert headers[b'x-test-header'] == b'My test header' + assert b"x-test-header" in headers + assert headers[b"x-test-header"] == b"My test header" server_ws = await request.accept() - await server_ws.send_message('test') + await server_ws.send_message("test") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - headers = [(b'X-Test-Header', b'My test header')] - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - extra_headers=headers) as client_ws: + assert isinstance(server, WebSocketServer) + headers = [(b"X-Test-Header", b"My test header")] + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, extra_headers=headers + ) as client_ws: await client_ws.get_message() @fail_after(1) -async def test_handshake_server_headers(nursery): - async def handler(request): - headers = [('X-Test-Header', 'My test header')] +async def test_handshake_server_headers(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + headers = [(b"X-Test-Header", b"My test header")] await request.accept(extra_headers=headers) server = await nursery.start(serve_websocket, handler, HOST, 0, None) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False - ) as client_ws: + assert isinstance(server, WebSocketServer) + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False) as client_ws: header_key, header_value = client_ws.handshake_headers[0] - assert header_key == b'x-test-header' - assert header_value == b'My test header' - - + assert header_key == b"x-test-header" + assert header_value == b"My test header" @fail_after(5) -async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_internal_ki( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """_reader_task._handle_ping_event triggers KeyboardInterrupt. user code also raises exception. Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ - async def ki_raising_ping_handler(*args, **kwargs) -> None: - print("raising ki") + + async def ki_raising_ping_handler(*args: object, **kwargs: object) -> None: raise KeyboardInterrupt - monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) - async def handler(request): + + monkeypatch.setattr( + WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler + ) + + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(KeyboardInterrupt) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): with trio.fail_after(1) as cs: @@ -470,51 +594,76 @@ async def handler(request): assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions) + @fail_after(5) -async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_internal_exc( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """_reader_task._handle_ping_event triggers ValueError. user code also raises exception. - internal exception is in __cause__ exceptiongroup and user exc is delivered + internal exception is in __context__ exceptiongroup and user exc is delivered """ - my_value_error = ValueError() - async def raising_ping_event(*args, **kwargs) -> None: - raise my_value_error + internal_error = ValueError() + internal_error.__context__ = TypeError() + user_error = NameError() + user_error_context = KeyError() + + async def raising_ping_event(*args: object, **kwargs: object) -> None: + raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) - async def handler(request): + + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - with pytest.raises(trio.TooSlowError) as exc_info: + assert isinstance(server, WebSocketServer) + with pytest.raises(type(user_error)) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): - with trio.fail_after(1) as cs: - cs.shield = True - await trio.sleep(2) + await trio.lowlevel.checkpoint() + user_error.__context__ = user_error_context + raise user_error + + assert exc_info.value is user_error + e_context = exc_info.value.__context__ + assert isinstance( + e_context, + BaseExceptionGroup, # pylint: disable=possibly-used-before-assignment + ) + assert internal_error in e_context.exceptions + assert user_error_context in e_context.exceptions - e_cause = exc_info.value.__cause__ - assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) - assert my_value_error in e_cause.exceptions @fail_after(5) -async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): +async def test_open_websocket_cancellations( + nursery: trio.Nursery, + monkeypatch: pytest.MonkeyPatch, + autojump_clock: trio.testing.MockClock, +) -> None: """Both user code and _reader_task raise Cancellation. Check that open_websocket reraises the one from user code for traceback reasons. """ - - async def sleeping_ping_event(*args, **kwargs) -> None: + async def sleeping_ping_event(*args: object, **kwargs: object) -> None: await trio.sleep_forever() # We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually # raise Cancelled upon being cancelled. For some reason it doesn't otherwise. monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event) - async def handler(request): + + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await server_ws.ping(b"a") + user_cancelled = None + user_cancelled_cause = None + user_cancelled_context = None server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with trio.move_on_after(2): with pytest.raises(trio.Cancelled) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): @@ -522,28 +671,38 @@ async def handler(request): await trio.sleep_forever() except trio.Cancelled as e: user_cancelled = e + user_cancelled_cause = e.__cause__ + user_cancelled_context = e.__context__ raise + assert exc_info.value is user_cancelled + assert exc_info.value.__cause__ is user_cancelled_cause + assert exc_info.value.__context__ is user_cancelled_context + def _trio_default_non_strict_exception_groups() -> bool: - assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" - return int(trio.__version__[2:4]) < 25 + trio_version = version("trio") + assert re.match(r"^0\.\d\d\.", trio_version), "unexpected trio versioning scheme" + return int(trio_version[2:4]) < 25 + @fail_after(1) async def test_handshake_exception_before_accept() -> None: - ''' In #107, a request handler that throws an exception before finishing the + """In #107, a request handler that throws an exception before finishing the handshake causes the task to hang. The proper behavior is to raise an - exception to the nursery as soon as possible. ''' - async def handler(request): + exception to the nursery as soon as possible.""" + + async def handler(request: WebSocketRequest) -> None: raise ValueError() # pylint fails to resolve that BaseExceptionGroup will always be available - with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment + with pytest.raises( + (BaseExceptionGroup, ValueError) + ) as exc: # pylint: disable=possibly-used-before-assignment async with trio.open_nursery() as nursery: - server = await nursery.start(serve_websocket, handler, HOST, 0, - None) - async with open_websocket(HOST, server.port, RESOURCE, - use_ssl=False): + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass if _trio_default_non_strict_exception_groups(): @@ -554,38 +713,62 @@ async def handler(request): # 2. WebSocketServer.run # 3. trio.serve_listeners # 4. WebSocketServer._handle_connection - assert RaisesGroup( - RaisesGroup( - RaisesGroup( - RaisesGroup(ValueError)))).matches(exc.value) + assert RaisesGroup(RaisesGroup(RaisesGroup(RaisesGroup(ValueError)))).matches( + exc.value + ) + + +async def test_user_exception_cause(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + await request.accept() + + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) + e_context = TypeError("foo") + e_primary = ValueError("bar") + e_cause = RuntimeError("zee") + with pytest.raises(ValueError) as exc_info: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + try: + raise e_context + except TypeError: + raise e_primary from e_cause + e = exc_info.value + assert e is e_primary + assert e.__cause__ is e_cause + assert e.__context__ is e_context @fail_after(1) -async def test_reject_handshake(nursery): - async def handler(request): - body = b'My body' +async def test_reject_handshake(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: + body = b"My body" await request.reject(400, body=body) server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) with pytest.raises(ConnectionRejected) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): pass exc = exc_info.value - assert exc.body == b'My body' + assert exc.body == b"My body" @fail_after(1) -async def test_reject_handshake_invalid_info_status(nursery): - ''' +async def test_reject_handshake_invalid_info_status(nursery: trio.Nursery) -> None: + """ An informational status code that is not 101 should cause the client to reject the handshake. Since it is an informational response, there will not be a response body, so this test exercises a different code path. - ''' - async def handler(stream): - await stream.send_all(b'HTTP/1.1 100 CONTINUE\r\n\r\n') + """ + + async def handler(stream: trio.SocketStream) -> None: + await stream.send_all(b"HTTP/1.1 100 CONTINUE\r\n\r\n") await stream.receive_some(max_bytes=1024) + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] with pytest.raises(ConnectionRejected) as exc_info: @@ -593,50 +776,52 @@ async def handler(stream): pass exc = exc_info.value assert exc.status_code == 100 - assert repr(exc) == 'ConnectionRejected' + assert repr(exc) == "ConnectionRejected" assert exc.body is None -async def test_handshake_protocol_error(echo_server): - ''' +async def test_handshake_protocol_error(echo_server: WebSocketServer) -> None: + """ If a client connects to a trio-websocket server and tries to speak HTTP instead of WebSocket, the server should reject the connection. (If the server does not catch the protocol exception, it will raise an exception up to the nursery level and fail the test.) - ''' + """ client_stream = await trio.open_tcp_stream(HOST, echo_server.port) async with client_stream: - await client_stream.send_all(b'GET / HTTP/1.1\r\n\r\n') + await client_stream.send_all(b"GET / HTTP/1.1\r\n\r\n") response = await client_stream.receive_some(1024) - assert response.startswith(b'HTTP/1.1 400') + assert response.startswith(b"HTTP/1.1 400") -async def test_client_send_and_receive(echo_conn): +async def test_client_send_and_receive(echo_conn: WebSocketConnection) -> None: async with echo_conn: - await echo_conn.send_message('This is a test message.') + await echo_conn.send_message("This is a test message.") received_msg = await echo_conn.get_message() - assert received_msg == 'This is a test message.' + assert received_msg == "This is a test message." -async def test_client_send_invalid_type(echo_conn): +async def test_client_send_invalid_type(echo_conn: WebSocketConnection) -> None: async with echo_conn: with pytest.raises(ValueError): - await echo_conn.send_message(object()) + await echo_conn.send_message(object()) # type: ignore[arg-type] -async def test_client_ping(echo_conn): +async def test_client_ping(echo_conn: WebSocketConnection) -> None: async with echo_conn: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.ping(b'B') + await echo_conn.ping(b"B") -async def test_client_ping_two_payloads(echo_conn): +async def test_client_ping_two_payloads(echo_conn: WebSocketConnection) -> None: pong_count = 0 - async def ping_and_count(): + + async def ping_and_count() -> None: nonlocal pong_count await echo_conn.ping() pong_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_count) @@ -644,17 +829,19 @@ async def ping_and_count(): assert pong_count == 2 -async def test_client_ping_same_payload(echo_conn): +async def test_client_ping_same_payload(echo_conn: WebSocketConnection) -> None: # This test verifies that two tasks can't ping with the same payload at the # same time. One of them should succeed and the other should get an # exception. exc_count = 0 - async def ping_and_catch(): + + async def ping_and_catch() -> None: nonlocal exc_count try: - await echo_conn.ping(b'A') + await echo_conn.ping(b"A") except ValueError: exc_count += 1 + async with echo_conn: async with trio.open_nursery() as nursery: nursery.start_soon(ping_and_catch) @@ -662,84 +849,106 @@ async def ping_and_catch(): assert exc_count == 1 -async def test_client_pong(echo_conn): +async def test_client_pong(echo_conn: WebSocketConnection) -> None: async with echo_conn: - await echo_conn.pong(b'A') + await echo_conn.pong(b"A") with pytest.raises(ConnectionClosed): - await echo_conn.pong(b'B') + await echo_conn.pong(b"B") -async def test_client_default_close(echo_conn): +async def test_client_default_close(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert not echo_conn.closed + assert isinstance(echo_conn.closed, CloseReason) assert echo_conn.closed.code == 1000 assert echo_conn.closed.reason is None - assert repr(echo_conn.closed) == 'CloseReason' + assert ( + repr(echo_conn.closed) == "CloseReason" + ) -async def test_client_nondefault_close(echo_conn): +async def test_client_nondefault_close(echo_conn: WebSocketConnection) -> None: async with echo_conn: assert not echo_conn.closed - await echo_conn.aclose(code=1001, reason='test reason') + await echo_conn.aclose(code=1001, reason="test reason") + assert isinstance(echo_conn.closed, CloseReason) assert echo_conn.closed.code == 1001 - assert echo_conn.closed.reason == 'test reason' + assert echo_conn.closed.reason == "test reason" -async def test_wrap_client_stream(nursery): +async def test_wrap_client_stream(nursery: trio.Nursery) -> None: listener = MemoryListener() - server = WebSocketServer(echo_request_handler, [listener]) + server = WebSocketServer(echo_request_handler, [listener]) # type: ignore[list-item] await nursery.start(server.run) stream = await listener.connect() - conn = await wrap_client_stream(nursery, stream, HOST, RESOURCE) + conn = await wrap_client_stream( + nursery, + stream, # type: ignore[arg-type] + HOST, + RESOURCE, + ) async with conn: assert not conn.closed - await conn.send_message('Hello from client!') + await conn.send_message("Hello from client!") msg = await conn.get_message() - assert msg == 'Hello from client!' - assert conn.local.startswith('StapledStream(') + assert msg == "Hello from client!" + assert isinstance(conn.local, str) + assert conn.local.startswith("StapledStream(") assert conn.closed -async def test_wrap_server_stream(nursery): - async def handler(stream): +async def test_wrap_server_stream(nursery: trio.Nursery) -> None: + async def handler(stream: trio.SocketStream) -> None: request = await wrap_server_stream(nursery, stream) server_ws = await request.accept() async with server_ws: assert not server_ws.closed msg = await server_ws.get_message() - assert msg == 'Hello from client!' + assert msg == "Hello from client!" assert server_ws.closed + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: - await client.send_message('Hello from client!') + await client.send_message("Hello from client!") @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_client_open_timeout(nursery, autojump_clock): - ''' +async def test_client_open_timeout( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: + """ The client times out waiting for the server to complete the opening handshake. - ''' - async def handler(request): + """ + + async def handler(request: WebSocketRequest) -> None: await trio.sleep(FORCE_TIMEOUT) await request.accept() - pytest.fail('Should not reach this line.') + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) with pytest.raises(ConnectionTimeout): - async with open_websocket(HOST, server.port, '/', use_ssl=False, - connect_timeout=TIMEOUT): + async with open_websocket( + HOST, server.port, "/", use_ssl=False, connect_timeout=TIMEOUT + ): pass @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_client_close_timeout(nursery, autojump_clock): - ''' +async def test_client_close_timeout( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: + """ This client times out waiting for the server to complete the closing handshake. @@ -747,68 +956,85 @@ async def test_client_close_timeout(nursery, autojump_clock): queue size is 0, and the client sends it exactly 1 message. This blocks the server's reader so it won't do the closing handshake for at least ``FORCE_TIMEOUT`` seconds. - ''' - async def handler(request): + """ + + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() await trio.sleep(FORCE_TIMEOUT) # The next line should raise ConnectionClosed. await server_ws.get_message() - pytest.fail('Should not reach this line.') + pytest.fail("Should not reach this line.") server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None, - message_queue_size=0)) + partial( + serve_websocket, handler, HOST, 0, ssl_context=None, message_queue_size=0 + ) + ) + assert isinstance(server, WebSocketServer) with pytest.raises(DisconnectionTimeout): - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - disconnect_timeout=TIMEOUT) as client_ws: - await client_ws.send_message('test') + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, disconnect_timeout=TIMEOUT + ) as client_ws: + await client_ws.send_message("test") -async def test_client_connect_networking_error(): - with patch('trio_websocket._impl.connect_websocket') as \ - connect_websocket_mock: +async def test_client_connect_networking_error() -> None: + with patch("trio_websocket._impl.connect_websocket") as connect_websocket_mock: connect_websocket_mock.side_effect = OSError() with pytest.raises(HandshakeError): - async with open_websocket(HOST, 0, '/', use_ssl=False): + async with open_websocket(HOST, 0, "/", use_ssl=False): pass @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_server_open_timeout(autojump_clock): - ''' +async def test_server_open_timeout(autojump_clock: trio.testing.MockClock) -> None: + """ The server times out waiting for the client to complete the opening handshake. Server timeouts don't raise exceptions, because handler tasks are launched in an internal nursery and sending exceptions wouldn't be helpful. Instead, timed out tasks silently end. - ''' - async def handler(request): - pytest.fail('This handler should not be called.') + """ + + async def handler(request: WebSocketRequest) -> None: + pytest.fail("This handler should not be called.") async with trio.open_nursery() as nursery: - server = await nursery.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=nursery, connect_timeout=TIMEOUT)) + server = await nursery.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=nursery, + connect_timeout=TIMEOUT, + ) + ) + assert isinstance(server, WebSocketServer) old_task_count = len(nursery.child_tasks) # This stream is not a WebSocket, so it won't send a handshake: await trio.open_tcp_stream(HOST, server.port) # Checkpoint so the server's handler task can spawn: await trio.sleep(0) - assert len(nursery.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(nursery.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # Sleep long enough to trigger server's connect_timeout: await trio.sleep(FORCE_TIMEOUT) - assert len(nursery.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(nursery.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the server task: nursery.cancel_scope.cancel() @fail_after(TIMEOUT_TEST_MAX_DURATION) -async def test_server_close_timeout(autojump_clock): - ''' +async def test_server_close_timeout(autojump_clock: trio.testing.MockClock) -> None: + """ The server times out waiting for the client to complete the closing handshake. @@ -819,33 +1045,44 @@ async def test_server_close_timeout(autojump_clock): To prevent the client from doing the closing handshake, we make sure that its message queue size is 0 and the server sends it exactly 1 message. This blocks the client's reader and prevents it from doing the client handshake. - ''' - async def handler(request): + """ + + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() # Send one message to block the client's reader task: - await ws.send_message('test') + await ws.send_message("test") async with trio.open_nursery() as outer: - server = await outer.start(partial(serve_websocket, handler, HOST, 0, - ssl_context=None, handler_nursery=outer, - disconnect_timeout=TIMEOUT)) + server = await outer.start( + partial( + serve_websocket, + handler, + HOST, + 0, + ssl_context=None, + handler_nursery=outer, + disconnect_timeout=TIMEOUT, + ) + ) + assert isinstance(server, WebSocketServer) old_task_count = len(outer.child_tasks) # Spawn client inside an inner nursery so that we can cancel it's reader # so that it won't do a closing handshake. async with trio.open_nursery() as inner: - await connect_websocket(inner, HOST, server.port, RESOURCE, - use_ssl=False) + await connect_websocket(inner, HOST, server.port, RESOURCE, use_ssl=False) # Checkpoint so the server can spawn a handler task: await trio.sleep(0) - assert len(outer.child_tasks) == old_task_count + 1, \ - "Server's reader task did not spawn" + assert ( + len(outer.child_tasks) == old_task_count + 1 + ), "Server's reader task did not spawn" # The client waits long enough to trigger the server's disconnect # timeout: await trio.sleep(FORCE_TIMEOUT) # The server should have cancelled the handler: - assert len(outer.child_tasks) == old_task_count, \ - "Server's reader task is still running" + assert ( + len(outer.child_tasks) == old_task_count + ), "Server's reader task is still running" # Cancel the client's reader task: inner.cancel_scope.cancel() @@ -853,155 +1090,175 @@ async def handler(request): outer.cancel_scope.cancel() -async def test_client_does_not_close_handshake(nursery): - async def handler(request): +async def test_client_does_not_close_handshake(nursery: trio.Nursery) -> None: + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() with pytest.raises(ConnectionClosed): await server_ws.get_message() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await client_ws.send_message('Hello from client!') + await client_ws.send_message("Hello from client!") -async def test_server_sends_after_close(nursery): +async def test_server_sends_after_close(nursery: trio.Nursery) -> None: done = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server_ws = await request.accept() with pytest.raises(ConnectionClosed): while True: - await server_ws.send_message('Hello from server') + await server_ws.send_message("Hello from server") done.set() server = await nursery.start(serve_websocket, handler, HOST, 0, None) + assert isinstance(server, WebSocketServer) stream = await trio.open_tcp_stream(HOST, server.port) client_ws = await wrap_client_stream(nursery, stream, HOST, RESOURCE) async with client_ws: # pump a few messages for x in range(2): - await client_ws.send_message('Hello from client') + await client_ws.send_message("Hello from client") await stream.aclose() await done.wait() -async def test_server_does_not_close_handshake(nursery): - async def handler(stream): +async def test_server_does_not_close_handshake(nursery: trio.Nursery) -> None: + async def handler(stream: trio.SocketStream) -> None: request = await wrap_server_stream(nursery, stream) server_ws = await request.accept() async with server_ws: await stream.aclose() with pytest.raises(ConnectionClosed): - await server_ws.send_message('Hello from client!') + await server_ws.send_message("Hello from client!") + serve_fn = partial(trio.serve_tcp, handler, 0, host=HOST) - listeners = await nursery.start(serve_fn) + raw_listeners = await nursery.start(serve_fn) + listeners = cast("list[trio.SocketListener]", raw_listeners) port = listeners[0].socket.getsockname()[1] async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: with pytest.raises(ConnectionClosed): await client.get_message() -async def test_server_handler_exit(nursery, autojump_clock): - async def handler(request): +async def test_server_handler_exit( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: + async def handler(request: WebSocketRequest) -> None: await request.accept() await trio.sleep(1) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) # connection should close when server handler exits with trio.fail_after(2): - async with open_websocket( - HOST, server.port, '/', use_ssl=False) as connection: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as connection: with pytest.raises(ConnectionClosed) as exc_info: await connection.get_message() exc = exc_info.value - assert exc.reason.name == 'NORMAL_CLOSURE' + assert isinstance(exc.reason, CloseReason) + assert exc.reason.name == "NORMAL_CLOSURE" @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_read_messages_after_remote_close(nursery): - ''' +async def test_read_messages_after_remote_close(nursery: trio.Nursery) -> None: + """ When the remote endpoint closes, the local endpoint can still read all of the messages sent prior to closing. Any attempt to read beyond that will raise ConnectionClosed. This test also exercises the configuration of the queue size. - ''' + """ server_closed = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") server_closed.set() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) # The client needs a message queue of size 2 so that it can buffer both # incoming messages without blocking the reader task. - async with open_websocket(HOST, server.port, '/', use_ssl=False, - message_queue_size=2) as client: + async with open_websocket( + HOST, server.port, "/", use_ssl=False, message_queue_size=2 + ) as client: await server_closed.wait() - assert await client.get_message() == '1' - assert await client.get_message() == '2' + assert await client.get_message() == "1" + assert await client.get_message() == "2" with pytest.raises(ConnectionClosed): await client.get_message() -async def test_no_messages_after_local_close(nursery): - ''' +async def test_no_messages_after_local_close(nursery: trio.Nursery) -> None: + """ If the local endpoint initiates closing, then pending messages are discarded and any attempt to read a message will raise ConnectionClosed. - ''' + """ client_closed = trio.Event() - async def handler(request): + async def handler(request: WebSocketRequest) -> None: # The server sends some messages and then closes. server = await request.accept() async with server: - await server.send_message('1') - await server.send_message('2') + await server.send_message("1") + await server.send_message("2") await client_closed.wait() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) - async with open_websocket(HOST, server.port, '/', use_ssl=False) as client: + async with open_websocket(HOST, server.port, "/", use_ssl=False) as client: pass with pytest.raises(ConnectionClosed): await client.get_message() client_closed.set() -async def test_cm_exit_with_pending_messages(echo_server, autojump_clock): - ''' +async def test_cm_exit_with_pending_messages( + echo_server: WebSocketServer, + autojump_clock: trio.testing.MockClock, +) -> None: + """ Regression test for #74, where a context manager was not able to exit when there were pending messages in the receive queue. - ''' + """ with trio.fail_after(1): - async with open_websocket(HOST, echo_server.port, RESOURCE, - use_ssl=False) as ws: - await ws.send_message('hello') + async with open_websocket( + HOST, echo_server.port, RESOURCE, use_ssl=False + ) as ws: + await ws.send_message("hello") # allow time for the server to respond - await trio.sleep(.1) + await trio.sleep(0.1) @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_max_message_size(nursery): - ''' +async def test_max_message_size(nursery: trio.Nursery) -> None: + """ Set the client's max message size to 100 bytes. The client can send a message larger than 100 bytes, but when it receives a message larger than 100 bytes, it closes the connection with code 1009. - ''' - async def handler(request): - ''' Similar to the echo_request_handler fixture except it runs in a - loop. ''' + """ + + async def handler(request: WebSocketRequest) -> None: + """Similar to the echo_request_handler fixture except it runs in a + loop.""" conn = await request.accept() while True: try: @@ -1011,44 +1268,56 @@ async def handler(request): break server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) - async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False, - max_message_size=100) as client: + async with open_websocket( + HOST, server.port, RESOURCE, use_ssl=False, max_message_size=100 + ) as client: # We can send and receive 100 bytes: - await client.send_message(b'A' * 100) + await client.send_message(b"A" * 100) msg = await client.get_message() assert len(msg) == 100 # We can send 101 bytes but cannot receive 101 bytes: - await client.send_message(b'B' * 101) + await client.send_message(b"B" * 101) with pytest.raises(ConnectionClosed): await client.get_message() assert client.closed assert client.closed.code == 1009 -async def test_server_close_client_disconnect_race(nursery, autojump_clock): +async def test_server_close_client_disconnect_race( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: """server attempts close just as client disconnects (issue #96)""" - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() ws._for_testing_peer_closed_connection = trio.Event() - await ws.send_message('foo') + await ws.send_message("foo") await ws._for_testing_peer_closed_connection.wait() # with bug, this would raise ConnectionClosed from websocket internal task await trio.aclose_forcefully(ws._stream) server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) - connection = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + connection = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) await connection.get_message() await connection.aclose() - await trio.sleep(.1) + await trio.sleep(0.1) -async def test_remote_close_local_message_race(nursery, autojump_clock): +async def test_remote_close_local_message_race( + nursery: trio.Nursery, + autojump_clock: trio.testing.MockClock, +) -> None: """as remote initiates close, local attempts message (issue #175) This exposed multiple problems in the trio-websocket API and implementation: @@ -1059,76 +1328,90 @@ async def test_remote_close_local_message_race(nursery, autojump_clock): * with wsproto >= 1.2.0, LocalProtocolError will be leaked """ - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: ws = await request.accept() await ws.get_message() await ws.aclose() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) client._for_testing_peer_closed_connection = trio.Event() - await client.send_message('foo') + await client.send_message("foo") await client._for_testing_peer_closed_connection.wait() with pytest.raises(ConnectionClosed): - await client.send_message('bar') + await client.send_message("bar") -async def test_message_after_local_close_race(nursery): +async def test_message_after_local_close_race(nursery: trio.Nursery) -> None: """test message send during local-initiated close handshake (issue #158)""" - async def handler(request: WebSocketRequest): + async def handler(request: WebSocketRequest) -> None: await request.accept() await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handler, HOST, 0, ssl_context=None)) + partial(serve_websocket, handler, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) orig_send = client._send close_sent = trio.Event() - async def _send_wrapper(event): + async def _send_wrapper(event: Event) -> None: if isinstance(event, CloseConnection): close_sent.set() return await orig_send(event) - client._send = _send_wrapper + client._send = _send_wrapper # type: ignore[method-assign] assert not client.closed nursery.start_soon(client.aclose) await close_sent.wait() assert client.closed with pytest.raises(ConnectionClosed): - await client.send_message('hello') + await client.send_message("hello") @fail_after(DEFAULT_TEST_MAX_DURATION) -async def test_server_tcp_closed_on_close_connection_event(nursery): +async def test_server_tcp_closed_on_close_connection_event( + nursery: trio.Nursery, +) -> None: """ensure server closes TCP immediately after receiving CloseConnection""" server_stream_closed = trio.Event() - async def _close_stream_stub(): + async def _close_stream_stub() -> None: assert not server_stream_closed.is_set() server_stream_closed.set() - async def handle_connection(request): + async def handle_connection(request: WebSocketRequest) -> None: ws = await request.accept() - ws._close_stream = _close_stream_stub + ws._close_stream = _close_stream_stub # type: ignore[method-assign] await trio.sleep_forever() server = await nursery.start( - partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None)) - client = await connect_websocket(nursery, HOST, server.port, - RESOURCE, use_ssl=False) + partial(serve_websocket, handle_connection, HOST, 0, ssl_context=None) + ) + assert isinstance(server, WebSocketServer) + client = await connect_websocket( + nursery, HOST, server.port, RESOURCE, use_ssl=False + ) # send a CloseConnection event to server but leave client connected await client._send(CloseConnection(code=1000)) await server_stream_closed.wait() -async def test_finalization_dropped_exception(echo_server, autojump_clock): +async def test_finalization_dropped_exception( + echo_server: WebSocketServer, + autojump_clock: trio.testing.MockClock, +) -> None: # Confirm that open_websocket finalization does not contribute to dropped # exceptions as described in https://github.com/python-trio/trio/issues/1559. with pytest.raises(ValueError): @@ -1140,7 +1423,7 @@ async def test_finalization_dropped_exception(echo_server, autojump_clock): raise ValueError -async def test_remote_close_rude(): +async def test_remote_close_rude() -> None: """ Bad ordering: 1. Remote close @@ -1150,14 +1433,19 @@ async def test_remote_close_rude(): """ client_stream, server_stream = memory_stream_pair() - async def client(): - client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE) + async def client() -> None: + client_conn = await wrap_client_stream( + nursery, + client_stream, # type: ignore[arg-type] + HOST, + RESOURCE, + ) assert not client_conn.closed - await client_conn.send_message('Hello from client!') + await client_conn.send_message("Hello from client!") with pytest.raises(ConnectionClosed): await client_conn.get_message() - async def server(): + async def server() -> None: server_request = await wrap_server_stream(nursery, server_stream) server_ws = await server_request.accept() assert not server_ws.closed @@ -1172,7 +1460,22 @@ async def server(): # pump the messages over memory_stream_pump(server_stream.send_stream, client_stream.receive_stream) - async with trio.open_nursery() as nursery: nursery.start_soon(server) nursery.start_soon(client) + + +def test_copy_exceptions() -> None: + # test that exceptions are copy- and pickleable + copy.copy(HandshakeError()) + copy.copy(ConnectionTimeout()) + copy.copy(DisconnectionTimeout()) + assert ( + copy.copy(ConnectionClosed("foo")).reason # type: ignore[arg-type] + == "foo" # type: ignore[comparison-overlap] + ) + + rej_copy = copy.copy(ConnectionRejected(404, ((b"a", b"b"),), b"c")) + assert rej_copy.status_code == 404 + assert rej_copy.headers == ((b"a", b"b"),) + assert rej_copy.body == b"c" diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index 82ca0ae..afa944c 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -1,20 +1,21 @@ +# pylint: disable=useless-import-alias from ._impl import ( - CloseReason, - ConnectionClosed, - ConnectionRejected, - ConnectionTimeout, - connect_websocket, - connect_websocket_url, - DisconnectionTimeout, - Endpoint, - HandshakeError, - open_websocket, - open_websocket_url, - WebSocketConnection, - WebSocketRequest, - WebSocketServer, - wrap_client_stream, - wrap_server_stream, - serve_websocket, + CloseReason as CloseReason, + ConnectionClosed as ConnectionClosed, + ConnectionRejected as ConnectionRejected, + ConnectionTimeout as ConnectionTimeout, + connect_websocket as connect_websocket, + connect_websocket_url as connect_websocket_url, + DisconnectionTimeout as DisconnectionTimeout, + Endpoint as Endpoint, + HandshakeError as HandshakeError, + open_websocket as open_websocket, + open_websocket_url as open_websocket_url, + WebSocketConnection as WebSocketConnection, + WebSocketRequest as WebSocketRequest, + WebSocketServer as WebSocketServer, + wrap_client_stream as wrap_client_stream, + wrap_server_stream as wrap_server_stream, + serve_websocket as serve_websocket, ) -from ._version import __version__ +from ._version import __version__ as __version__ diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index a71e0be..ba51bc0 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -2,7 +2,7 @@ import sys from collections import OrderedDict -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, AbstractAsyncContextManager from functools import partial from ipaddress import ip_address import itertools @@ -11,7 +11,18 @@ import ssl import struct import urllib.parse -from typing import Iterable, List, Optional, Union +from typing import ( + Any, + List, + NoReturn, + Optional, + Union, + TypeVar, + TYPE_CHECKING, + Generic, + cast, +) +from importlib.metadata import version import outcome import trio @@ -36,18 +47,35 @@ # pylint doesn't care about the version_info check, so need to ignore the warning from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) +if TYPE_CHECKING: + from types import TracebackType + from typing_extensions import Final + from collections.abc import ( + AsyncGenerator, + Awaitable, + Callable, + Iterable, + Coroutine, + Sequence, + ) + +_IS_TRIO_MULTI_ERROR: Final = tuple(map(int, version("trio").split(".")[:2])) < (0, 22) if _IS_TRIO_MULTI_ERROR: _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member else: - _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + _TRIO_EXC_GROUP_TYPE = ( + BaseExceptionGroup # pylint: disable=possibly-used-before-assignment + ) + +CONN_TIMEOUT: Final = 60 # default connect & disconnect timeout, in seconds +MESSAGE_QUEUE_SIZE: Final = 1 +MAX_MESSAGE_SIZE: Final = 2**20 # 1 MiB +RECEIVE_BYTES: Final = 4 * 2**10 # 4 KiB +logger: Final = logging.getLogger("trio-websocket") -CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds -MESSAGE_QUEUE_SIZE = 1 -MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB -RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB -logger = logging.getLogger('trio-websocket') +T = TypeVar("T") +E = TypeVar("E", bound=BaseException) class TrioWebsocketInternalError(Exception): @@ -57,7 +85,7 @@ class TrioWebsocketInternalError(Exception): """ -def _ignore_cancel(exc): +def _ignore_cancel(exc: E) -> E | None: return None if isinstance(exc, trio.Cancelled) else exc @@ -71,22 +99,32 @@ class _preserve_current_exception: https://github.com/python-trio/trio/issues/1559 https://gitter.im/python-trio/general?at=5faf2293d37a1a13d6a582cf """ + __slots__ = ("_armed",) - def __init__(self): + def __init__(self) -> None: self._armed = False - def __enter__(self): + def __enter__(self) -> None: self._armed = sys.exc_info()[1] is not None - def __exit__(self, ty, value, tb): + def __exit__( + self, + ty: type[BaseException] | None, + value: BaseException | None, + tb: TracebackType | None, + ) -> bool: if value is None or not self._armed: return False if _IS_TRIO_MULTI_ERROR: # pragma: no cover - filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member - elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment - filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) + filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # type: ignore[attr-defined] # pylint: disable=no-member + elif isinstance( + value, BaseExceptionGroup + ): # pylint: disable=possibly-used-before-assignment + filtered_exception = value.subgroup( + lambda exc: not isinstance(exc, trio.Cancelled) + ) else: filtered_exception = _ignore_cancel(value) return filtered_exception is None @@ -94,19 +132,19 @@ def __exit__(self, ty, value, tb): @asynccontextmanager async def open_websocket( - host: str, - port: int, - resource: str, - *, - use_ssl: Union[bool, ssl.SSLContext], - subprotocols: Optional[Iterable[str]] = None, - extra_headers: Optional[list[tuple[bytes,bytes]]] = None, - message_queue_size: int = MESSAGE_QUEUE_SIZE, - max_message_size: int = MAX_MESSAGE_SIZE, - connect_timeout: float = CONN_TIMEOUT, - disconnect_timeout: float = CONN_TIMEOUT - ): - ''' + host: str, + port: int, + resource: str, + *, + use_ssl: Union[bool, ssl.SSLContext], + subprotocols: Optional[Iterable[str]] = None, + extra_headers: Optional[list[tuple[bytes, bytes]]] = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, +) -> AsyncGenerator[WebSocketConnection, None]: + """ Open a WebSocket client connection to a host. This async context manager connects when entering the context manager and @@ -137,7 +175,7 @@ async def open_websocket( :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' + """ # This context manager tries very very hard not to raise an exceptiongroup # in order to be as transparent as possible for the end user. @@ -151,24 +189,29 @@ async def open_websocket( # yield to user code. If only one of those raise a non-cancelled exception # we will raise that non-cancelled exception. # If we get multiple cancelled, we raise the user's cancelled. - # If both raise exceptions, we raise the user code's exception with the entire - # exception group as the __cause__. + # If both raise exceptions, we raise the user code's exception with __context__ + # set to a group containing internal exception(s) + any user exception __context__ # If we somehow get multiple exceptions, but no user exception, then we raise # TrioWebsocketInternalError. # If closing the connection fails, then that will be raised as the top # exception in the last `finally`. If we encountered exceptions in user code - # or in reader task then they will be set as the `__cause__`. - + # or in reader task then they will be set as the `__context__`. async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: try: with trio.fail_after(connect_timeout): - return await connect_websocket(nursery, host, port, - resource, use_ssl=use_ssl, subprotocols=subprotocols, + return await connect_websocket( + nursery, + host, + port, + resource, + use_ssl=use_ssl, + subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) except trio.TooSlowError: raise ConnectionTimeout from None except OSError as e: @@ -181,10 +224,27 @@ async def _close_connection(connection: WebSocketConnection) -> None: except trio.TooSlowError: raise DisconnectionTimeout from None - connection: WebSocketConnection|None=None + def _raise(exc: BaseException) -> NoReturn: + """This helper allows re-raising an exception without __context__ being set.""" + # cause does not need special handlng, we simply avoid using `raise .. from ..` + __tracebackhide__ = True + context = exc.__context__ + try: + raise exc + finally: + exc.__context__ = context + del exc, context + + connection: WebSocketConnection | None = None close_result: outcome.Maybe[None] | None = None user_error = None + # Unwrapping exception groups has a lot of pitfalls, one of them stemming from + # the exception we raise also being inside the group that's set as the context. + # This leads to loss of info unless properly handled. + # See https://github.com/python-trio/flake8-async/issues/298 + # We therefore avoid having the exceptiongroup included as either cause or context + try: async with trio.open_nursery() as new_nursery: result = await outcome.acapture(_open_connection, new_nursery) @@ -205,10 +265,10 @@ async def _close_connection(connection: WebSocketConnection) -> None: except _TRIO_EXC_GROUP_TYPE as e: # user_error, or exception bubbling up from _reader_task if len(e.exceptions) == 1: - raise e.exceptions[0] + _raise(e.exceptions[0]) # contains at most 1 non-cancelled exceptions - exception_to_raise: BaseException|None = None + exception_to_raise: BaseException | None = None for sub_exc in e.exceptions: if not isinstance(sub_exc, trio.Cancelled): if exception_to_raise is not None: @@ -218,25 +278,43 @@ async def _close_connection(connection: WebSocketConnection) -> None: else: if exception_to_raise is None: # all exceptions are cancelled - # prefer raising the one from the user, for traceback reasons + # we reraise the user exception and throw out internal if user_error is not None: - # no reason to raise from e, just to include a bunch of extra - # cancelleds. - raise user_error # pylint: disable=raise-missing-from + _raise(user_error) # multiple internal Cancelled is not possible afaik - raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from - raise exception_to_raise + # but if so we just raise one of them + _raise(e.exceptions[0]) # pragma: no cover + # raise the non-cancelled exception + _raise(exception_to_raise) - # if we have any KeyboardInterrupt in the group, make sure to raise it. + # if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt + # with the group as cause & context for sub_exc in e.exceptions: if isinstance(sub_exc, KeyboardInterrupt): - raise sub_exc from e + raise KeyboardInterrupt from e # Both user code and internal code raised non-cancelled exceptions. - # We "hide" the internal exception(s) in the __cause__ and surface - # the user_error. + # We set the context to be an exception group containing internal exceptions + # and, if not None, `user_error.__context__` if user_error is not None: - raise user_error from e + exceptions = [subexc for subexc in e.exceptions if subexc is not user_error] + eg_substr = "" + # there's technically loss of info here, with __suppress_context__=True you + # still have original __context__ available, just not printed. But we delete + # it completely because we can't partially suppress the group + if ( + user_error.__context__ is not None + and not user_error.__suppress_context__ + ): + exceptions.append(user_error.__context__) + eg_substr = " and the context for the user exception" + eg_str = ( + "Both internal and user exceptions encountered. This group contains " + "the internal exception(s)" + eg_substr + "." + ) + user_error.__context__ = BaseExceptionGroup(eg_str, exceptions) + user_error.__suppress_context__ = False + _raise(user_error) raise TrioWebsocketInternalError( "The trio-websocket API is not expected to raise multiple exceptions. " @@ -248,17 +326,24 @@ async def _close_connection(connection: WebSocketConnection) -> None: if close_result is not None: close_result.unwrap() - # error setting up, unwrap that exception if connection is None: result.unwrap() -async def connect_websocket(nursery, host, port, resource, *, use_ssl, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE - ) -> WebSocketConnection: - ''' +async def connect_websocket( + nursery: trio.Nursery, + host: str, + port: int, + resource: str, + *, + use_ssl: bool | ssl.SSLContext, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: + """ Return an open WebSocket client connection to a host. This function is used to specify a custom nursery to run connection @@ -286,7 +371,7 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' + """ if use_ssl is True: ssl_context = ssl.create_default_context() elif use_ssl is False: @@ -294,37 +379,53 @@ async def connect_websocket(nursery, host, port, resource, *, use_ssl, elif isinstance(use_ssl, ssl.SSLContext): ssl_context = use_ssl else: - raise TypeError('`use_ssl` argument must be bool or ssl.SSLContext') - - logger.debug('Connecting to ws%s://%s:%d%s', - '' if ssl_context is None else 's', host, port, resource) + raise TypeError("`use_ssl` argument must be bool or ssl.SSLContext") + + logger.debug( + "Connecting to ws%s://%s:%d%s", + "" if ssl_context is None else "s", + host, + port, + resource, + ) stream: trio.SSLStream[trio.SocketStream] | trio.SocketStream if ssl_context is None: stream = await trio.open_tcp_stream(host, port) else: - stream = await trio.open_ssl_over_tcp_stream(host, port, - ssl_context=ssl_context, https_compatible=True) + stream = await trio.open_ssl_over_tcp_stream( + host, port, ssl_context=ssl_context, https_compatible=True + ) if port in (80, 443): host_header = host else: - host_header = f'{host}:{port}' - connection = WebSocketConnection(stream, + host_header = f"{host}:{port}" + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), host=host_header, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -def open_websocket_url(url, ssl_context=None, *, subprotocols=None, - extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, - connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT): - ''' +def open_websocket_url( + url: str, + ssl_context: ssl.SSLContext | None = None, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, +) -> AbstractAsyncContextManager[WebSocketConnection]: + """ Open a WebSocket client connection to a URL. This async context manager connects when entering the context manager and @@ -353,19 +454,33 @@ def open_websocket_url(url, ssl_context=None, *, subprotocols=None, :raises HandshakeError: for any networking error, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. - ''' - host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return open_websocket(host, port, resource, use_ssl=ssl_context, - subprotocols=subprotocols, extra_headers=extra_headers, + """ + host, port, resource, return_ssl_context = _url_to_host(url, ssl_context) + return open_websocket( + host, + port, + resource, + use_ssl=return_ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size, - connect_timeout=connect_timeout, disconnect_timeout=disconnect_timeout) - - -async def connect_websocket_url(nursery, url, ssl_context=None, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) + + +async def connect_websocket_url( + nursery: trio.Nursery, + url: str, + ssl_context: ssl.SSLContext | None = None, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: + """ Return an open WebSocket client connection to a URL. This function is used to specify a custom nursery to run connection @@ -390,16 +505,26 @@ async def connect_websocket_url(nursery, url, ssl_context=None, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' - host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return await connect_websocket(nursery, host, port, resource, - use_ssl=ssl_context, subprotocols=subprotocols, - extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + """ + host, port, resource, return_ssl_context = _url_to_host(url, ssl_context) + return await connect_websocket( + nursery, + host, + port, + resource, + use_ssl=return_ssl_context, + subprotocols=subprotocols, + extra_headers=extra_headers, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + ) -def _url_to_host(url, ssl_context): - ''' +def _url_to_host( + url: str, + ssl_context: ssl.SSLContext | None, +) -> tuple[str, int, str, ssl.SSLContext | bool]: + """ Convert a WebSocket URL to a (host,port,resource) tuple. The returned ``ssl_context`` is either the same object that was passed in, @@ -409,16 +534,21 @@ def _url_to_host(url, ssl_context): :param str url: A WebSocket URL. :type ssl_context: ssl.SSLContext or None :returns: A tuple of ``(host, port, resource, ssl_context)``. - ''' + """ url = str(url) # For backward compat with isinstance(url, yarl.URL). parts = urllib.parse.urlsplit(url) - if parts.scheme not in ('ws', 'wss'): + if parts.scheme not in ("ws", "wss"): raise ValueError('WebSocket URL scheme must be "ws:" or "wss:"') + return_ssl_context: ssl.SSLContext | bool if ssl_context is None: - ssl_context = parts.scheme == 'wss' - elif parts.scheme == 'ws': - raise ValueError('SSL context must be None for ws: URL scheme') + return_ssl_context = parts.scheme == "wss" + elif parts.scheme == "ws": + raise ValueError("SSL context must be None for ws: URL scheme") + else: + return_ssl_context = ssl_context host = parts.hostname + if host is None: + raise ValueError("URL host must not be None") if parts.port is not None: port = parts.port else: @@ -428,16 +558,24 @@ def _url_to_host(url, ssl_context): # If the target URI's path component is empty, the client MUST # send "/" as the path within the origin-form of request-target. if not path_qs: - path_qs = '/' - if '?' in url: - path_qs += '?' + parts.query - return host, port, path_qs, ssl_context - - -async def wrap_client_stream(nursery, stream, host, resource, *, - subprotocols=None, extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' + path_qs = "/" + if "?" in url: + path_qs += "?" + parts.query + return host, port, path_qs, return_ssl_context + + +async def wrap_client_stream( + nursery: trio.Nursery, + stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], + host: str, + resource: str, + *, + subprotocols: Iterable[str] | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketConnection: + """ Wrap an arbitrary stream in a WebSocket connection. This is a low-level function only needed in rare cases. In most cases, you @@ -461,21 +599,29 @@ async def wrap_client_stream(nursery, stream, host, resource, *, ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). :rtype: WebSocketConnection - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.CLIENT), - host=host, path=resource, - client_subprotocols=subprotocols, client_extra_headers=extra_headers, + host=host, + path=resource, + client_subprotocols=subprotocols, + client_extra_headers=extra_headers, message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -async def wrap_server_stream(nursery, stream, - message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE): - ''' +async def wrap_server_stream( + nursery: trio.Nursery, + stream: trio.abc.Stream, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, +) -> WebSocketRequest: + """ Wrap an arbitrary stream in a server-side WebSocket. This is a low-level function only needed in rare cases. In most cases, you @@ -490,21 +636,32 @@ async def wrap_server_stream(nursery, stream, then the connection is closed with code 1009 (Message Too Big). :type stream: trio.abc.Stream :rtype: WebSocketRequest - ''' - connection = WebSocketConnection(stream, + """ + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=message_queue_size, - max_message_size=max_message_size) + max_message_size=max_message_size, + ) nursery.start_soon(connection._reader_task) request = await connection._get_request() return request -async def serve_websocket(handler, host, port, ssl_context, *, - handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): - ''' +async def serve_websocket( + handler: Callable[[WebSocketRequest], Awaitable[None]], + host: str | bytes | None, + port: int, + ssl_context: ssl.SSLContext | None, + *, + handler_nursery: trio.Nursery | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, + task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, +) -> NoReturn: + """ Serve a WebSocket over TCP. This function supports the Trio nursery start protocol: ``server = await @@ -538,65 +695,89 @@ async def serve_websocket(handler, host, port, ssl_context, *, to finish the closing handshake before timing out. :param task_status: Part of Trio nursery start protocol. :returns: This function runs until cancelled. - ''' + """ + open_tcp_listeners: ( + partial[Coroutine[Any, Any, list[trio.SocketListener]]] + | partial[Coroutine[Any, Any, list[trio.SSLListener[trio.SocketStream]]]] + ) if ssl_context is None: open_tcp_listeners = partial(trio.open_tcp_listeners, port, host=host) else: - open_tcp_listeners = partial(trio.open_ssl_over_tcp_listeners, port, - ssl_context, host=host, https_compatible=True) + open_tcp_listeners = partial( + trio.open_ssl_over_tcp_listeners, + port, + ssl_context, + host=host, + https_compatible=True, + ) listeners = await open_tcp_listeners() - server = WebSocketServer(handler, listeners, - handler_nursery=handler_nursery, message_queue_size=message_queue_size, - max_message_size=max_message_size, connect_timeout=connect_timeout, - disconnect_timeout=disconnect_timeout) + server = WebSocketServer( + handler, + listeners, + handler_nursery=handler_nursery, + message_queue_size=message_queue_size, + max_message_size=max_message_size, + connect_timeout=connect_timeout, + disconnect_timeout=disconnect_timeout, + ) await server.run(task_status=task_status) class HandshakeError(Exception): - ''' + """ There was an error during connection or disconnection with the websocket server. - ''' + """ + class ConnectionTimeout(HandshakeError): - '''There was a timeout when connecting to the websocket server.''' + """There was a timeout when connecting to the websocket server.""" + class DisconnectionTimeout(HandshakeError): - '''There was a timeout when disconnecting from the websocket server.''' + """There was a timeout when disconnecting from the websocket server.""" + class ConnectionClosed(Exception): - ''' + """ A WebSocket operation cannot be completed because the connection is closed or in the process of closing. - ''' - def __init__(self, reason): - ''' + """ + + def __init__(self, reason: CloseReason | None) -> None: + """ Constructor. :param reason: :type reason: CloseReason - ''' - super().__init__() + """ + super().__init__(reason) self.reason = reason - def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}<{self.reason}>' + def __repr__(self) -> str: + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" class ConnectionRejected(HandshakeError): - ''' + """ A WebSocket connection could not be established because the server rejected the connection attempt. - ''' - def __init__(self, status_code, headers, body): - ''' + """ + + def __init__( + self, + status_code: int, + headers: tuple[tuple[bytes, bytes], ...], + body: bytes | None, + ) -> None: + """ Constructor. :param reason: :type reason: CloseReason - ''' - super().__init__() + """ + super().__init__(status_code, headers, body) #: a 3 digit HTTP status code self.status_code = status_code #: a tuple of 2-tuples containing header key/value pairs @@ -604,145 +785,165 @@ def __init__(self, status_code, headers, body): #: an optional ``bytes`` response body self.body = body - def __repr__(self): - ''' Return representation. ''' - return f'{self.__class__.__name__}' + def __repr__(self) -> str: + """Return representation.""" + return f"{self.__class__.__name__}" class CloseReason: - ''' Contains information about why a WebSocket was closed. ''' - def __init__(self, code, reason): - ''' + """Contains information about why a WebSocket was closed.""" + + def __init__(self, code: int, reason: str | None) -> None: + """ Constructor. :param int code: :param Optional[str] reason: - ''' + """ self._code = code try: self._name = wsframeproto.CloseReason(code).name except ValueError: if 1000 <= code <= 2999: - self._name = 'RFC_RESERVED' + self._name = "RFC_RESERVED" elif 3000 <= code <= 3999: - self._name = 'IANA_RESERVED' + self._name = "IANA_RESERVED" elif 4000 <= code <= 4999: - self._name = 'PRIVATE_RESERVED' + self._name = "PRIVATE_RESERVED" else: - self._name = 'INVALID_CODE' + self._name = "INVALID_CODE" self._reason = reason @property - def code(self): - ''' (Read-only) The numeric close code. ''' + def code(self) -> int: + """(Read-only) The numeric close code.""" return self._code @property - def name(self): - ''' (Read-only) The human-readable close code. ''' + def name(self) -> str: + """(Read-only) The human-readable close code.""" return self._name @property - def reason(self): - ''' (Read-only) An arbitrary reason string. ''' + def reason(self) -> str | None: + """(Read-only) An arbitrary reason string.""" return self._reason - def __repr__(self): - ''' Show close code, name, and reason. ''' - return f'{self.__class__.__name__}' \ - f'' + def __repr__(self) -> str: + """Show close code, name, and reason.""" + return ( + f"{self.__class__.__name__}" + f"" + ) + +NULL: Final = object() -class Future: - ''' Represents a value that will be available in the future. ''' - def __init__(self): - ''' Constructor. ''' - self._value = None + +class Future(Generic[T]): + """Represents a value that will be available in the future.""" + + def __init__(self) -> None: + """Constructor.""" + # We do some type shenanigins + # Would do `T | Literal[NULL]` but that's not right apparently. + self._value: T = cast(T, NULL) self._value_event = trio.Event() - def set_value(self, value): - ''' + def set_value(self, value: T) -> None: + """ Set a value, which will notify any waiters. :param value: - ''' + """ self._value = value self._value_event.set() - async def wait_value(self): - ''' + async def wait_value(self) -> T: + """ Wait for this future to have a value, then return it. :returns: The value set by ``set_value()``. - ''' + """ await self._value_event.wait() + assert self._value is not NULL return self._value class WebSocketRequest: - ''' + """ Represents a handshake presented by a client to a server. The server may modify the handshake or leave it as is. The server should call ``accept()`` to finish the handshake and obtain a connection object. - ''' - def __init__(self, connection, event): - ''' + """ + + def __init__( + self, + connection: WebSocketConnection, + event: wsproto.events.Request, + ) -> None: + """ Constructor. :param WebSocketConnection connection: :type event: wsproto.events.Request - ''' + """ self._connection = connection self._event = event @property - def headers(self): - ''' + def headers(self) -> list[tuple[bytes, bytes]]: + """ HTTP headers represented as a list of (name, value) pairs. :rtype: list[tuple] - ''' + """ return self._event.extra_headers @property - def path(self): - ''' + def path(self) -> str: + """ The requested URL path. :rtype: str - ''' + """ return self._event.target @property - def proposed_subprotocols(self): - ''' + def proposed_subprotocols(self) -> tuple[str, ...]: + """ A tuple of protocols proposed by the client. :rtype: tuple[str] - ''' + """ return tuple(self._event.subprotocols) @property - def local(self): - ''' + def local(self) -> Endpoint | str: + """ The connection's local endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.local @property - def remote(self): - ''' + def remote(self) -> Endpoint | str: + """ The connection's remote endpoint. :rtype: Endpoint or str - ''' + """ return self._connection.remote - async def accept(self, *, subprotocol=None, extra_headers=None): - ''' + async def accept( + self, + *, + subprotocol: str | None = None, + extra_headers: list[tuple[bytes, bytes]] | None = None, + ) -> WebSocketConnection: + """ Accept the request and return a connection object. :param subprotocol: The selected subprotocol for this connection. @@ -751,14 +952,20 @@ async def accept(self, *, subprotocol=None, extra_headers=None): send as HTTP headers. :type extra_headers: list[tuple[bytes,bytes]] or None :rtype: WebSocketConnection - ''' + """ if extra_headers is None: extra_headers = [] await self._connection._accept(self._event, subprotocol, extra_headers) return self._connection - async def reject(self, status_code, *, extra_headers=None, body=None): - ''' + async def reject( + self, + status_code: int, + *, + extra_headers: list[tuple[bytes, bytes]] | None = None, + body: bytes | None = None, + ) -> None: + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -769,14 +976,18 @@ async def reject(self, status_code, *, extra_headers=None, body=None): :param body: If provided, this data will be sent in the response body, otherwise no response body will be sent. :type body: bytes or None - ''' + """ extra_headers = extra_headers or [] - body = body or b'' + body = body or b"" await self._connection._reject(status_code, extra_headers, body) -def _get_stream_endpoint(stream, *, local): - ''' +def _get_stream_endpoint( + stream: trio.abc.Stream, + *, + local: bool, +) -> Endpoint | str: + """ Construct an endpoint from a stream. :param trio.Stream stream: @@ -784,13 +995,14 @@ def _get_stream_endpoint(stream, *, local): :returns: An endpoint instance or ``repr()`` for streams that cannot be represented as an endpoint. :rtype: Endpoint or str - ''' + """ socket, is_ssl = None, False if isinstance(stream, trio.SocketStream): socket = stream.socket elif isinstance(stream, trio.SSLStream): socket = stream.transport_stream.socket is_ssl = True + endpoint: Endpoint | str if socket: addr, port, *_ = socket.getsockname() if local else socket.getpeername() endpoint = Endpoint(addr, port, is_ssl) @@ -800,22 +1012,23 @@ def _get_stream_endpoint(stream, *, local): class WebSocketConnection(trio.abc.AsyncResource): - ''' A WebSocket connection. ''' + """A WebSocket connection.""" CONNECTION_ID = itertools.count() def __init__( - self, - stream: trio.SocketStream | trio.SSLStream[trio.SocketStream], - ws_connection: wsproto.WSConnection, - *, - host=None, - path=None, - client_subprotocols=None, client_extra_headers=None, - message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE - ): - ''' + self, + stream: trio.abc.Stream, + ws_connection: wsproto.WSConnection, + *, + host: str | None = None, + path: str | None = None, + client_subprotocols: Iterable[str] | None = None, + client_extra_headers: list[tuple[bytes, bytes]] | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + ) -> None: + """ Constructor. Generally speaking, users are discouraged from directly instantiating a @@ -840,7 +1053,7 @@ def __init__( :param int max_message_size: The maximum message size as measured by ``len()``. If a message is received that is larger than this size, then the connection is closed with code 1009 (Message Too Big). - ''' + """ # NOTE: The implementation uses _close_reason for more than an advisory # purpose. It's critical internal state, indicating when the # connection is closed or closing. @@ -854,24 +1067,31 @@ def __init__( self._max_message_size = max_message_size self._reader_running = True if ws_connection.client: - self._initial_request: Optional[Request] = Request(host=host, target=path, - subprotocols=client_subprotocols, - extra_headers=client_extra_headers or []) + assert host is not None + assert path is not None + self._initial_request: Optional[Request] = Request( + host=host, + target=path, + subprotocols=list(client_subprotocols or ()), + extra_headers=client_extra_headers or [], + ) else: self._initial_request = None self._path = path self._subprotocol: Optional[str] = None - self._handshake_headers: tuple[tuple[str,str], ...] = tuple() + self._handshake_headers: tuple[tuple[bytes, bytes], ...] = () self._reject_status = 0 - self._reject_headers: tuple[tuple[str,str], ...] = tuple() - self._reject_body = b'' + self._reject_headers: tuple[tuple[bytes, bytes], ...] = () + self._reject_body = b"" self._send_channel, self._recv_channel = trio.open_memory_channel[ Union[bytes, str] ](message_queue_size) self._pings: OrderedDict[bytes, trio.Event] = OrderedDict() # Set when the server has received a connection request event. This # future is never set on client connections. - self._connection_proposal = Future() + self._connection_proposal: Future[WebSocketRequest] | None = Future[ + WebSocketRequest + ]() # Set once the WebSocket open handshake takes place, i.e. # ConnectionRequested for server or ConnectedEstablished for client. self._open_handshake = trio.Event() @@ -883,78 +1103,80 @@ def __init__( self._for_testing_peer_closed_connection = trio.Event() @property - def closed(self): - ''' + def closed(self) -> CloseReason | None: + """ (Read-only) The reason why the connection was or is being closed, else ``None``. :rtype: Optional[CloseReason] - ''' + """ return self._close_reason @property - def is_client(self): - ''' (Read-only) Is this a client instance? ''' + def is_client(self) -> bool: + """(Read-only) Is this a client instance?""" return self._wsproto.client @property - def is_server(self): - ''' (Read-only) Is this a server instance? ''' + def is_server(self) -> bool: + """(Read-only) Is this a server instance?""" return not self._wsproto.client @property - def local(self): - ''' + def local(self) -> Endpoint | str: + """ The local endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=True) @property - def remote(self): - ''' + def remote(self) -> Endpoint | str: + """ The remote endpoint of the connection. :rtype: Endpoint or str - ''' + """ return _get_stream_endpoint(self._stream, local=False) @property - def path(self): - ''' + def path(self) -> str | None: + """ The requested URL path. For clients, this is set when the connection is instantiated. For servers, it is set after the handshake completes. - :rtype: str - ''' + :rtype: str or None + """ return self._path @property - def subprotocol(self): - ''' + def subprotocol(self) -> str | None: + """ (Read-only) The negotiated subprotocol, or ``None`` if there is no subprotocol. This is only valid after the opening handshake is complete. :rtype: str or None - ''' + """ return self._subprotocol @property - def handshake_headers(self): - ''' + def handshake_headers(self) -> tuple[tuple[bytes, bytes], ...]: + """ The HTTP headers that were sent by the remote during the handshake, stored as 2-tuples containing key/value pairs. Header keys are always lower case. :rtype: tuple[tuple[str,str]] - ''' + """ return self._handshake_headers - async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-differ - ''' + async def aclose( + self, code: int = 1000, reason: str | None = None + ) -> None: # pylint: disable=arguments-differ + """ Close the WebSocket connection. This sends a closing frame and suspends until the connection is closed. @@ -967,11 +1189,11 @@ async def aclose(self, code=1000, reason=None): # pylint: disable=arguments-dif :param int code: A 4-digit code number indicating the type of closure. :param str reason: An optional string describing the closure. - ''' + """ with _preserve_current_exception(): await self._aclose(code, reason) - async def _aclose(self, code, reason): + async def _aclose(self, code: int, reason: str | None) -> None: if self._close_reason: # Per AsyncResource interface, calling aclose() on a closed resource # should succeed. @@ -982,8 +1204,10 @@ async def _aclose(self, code, reason): # event to peer, while setting the local close reason to normal. self._close_reason = CloseReason(1000, None) await self._send(CloseConnection(code=code, reason=reason)) - elif self._wsproto.state in (ConnectionState.CONNECTING, - ConnectionState.REJECTING): + elif self._wsproto.state in ( + ConnectionState.CONNECTING, + ConnectionState.REJECTING, + ): self._close_handshake.set() # TODO: shouldn't the receive channel be closed earlier, so that # get_message() during send of the CloseConneciton event fails? @@ -997,8 +1221,8 @@ async def _aclose(self, code, reason): # stream is closed. await self._close_stream() - async def get_message(self): - ''' + async def get_message(self) -> str | bytes: + """ Receive the next WebSocket message. If no message is available immediately, then this function blocks until @@ -1013,15 +1237,15 @@ async def get_message(self): :rtype: str or bytes :raises ConnectionClosed: if the connection is closed. - ''' + """ try: message = await self._recv_channel.receive() except (trio.ClosedResourceError, trio.EndOfChannel): raise ConnectionClosed(self._close_reason) from None return message - async def ping(self, payload: bytes|None=None): - ''' + async def ping(self, payload: bytes | None = None) -> None: + """ Send WebSocket ping to remote endpoint and wait for a correspoding pong. Each in-flight ping must include a unique payload. This function sends @@ -1039,56 +1263,62 @@ async def ping(self, payload: bytes|None=None): :raises ConnectionClosed: if connection is closed. :raises ValueError: if ``payload`` is identical to another in-flight ping. - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) if payload in self._pings: - raise ValueError(f'Payload value {payload!r} is already in flight.') + raise ValueError(f"Payload value {payload!r} is already in flight.") if payload is None: - payload = struct.pack('!I', random.getrandbits(32)) + payload = struct.pack("!I", random.getrandbits(32)) event = trio.Event() self._pings[payload] = event await self._send(Ping(payload=payload)) await event.wait() - async def pong(self, payload=None): - ''' + async def pong(self, payload: bytes | None = None) -> None: + """ Send an unsolicted pong. :param payload: The pong's payload. If ``None``, then no payload is sent. :type payload: bytes or None :raises ConnectionClosed: if connection is closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) - await self._send(Pong(payload=payload)) + await self._send(Pong(payload=payload or b"")) - async def send_message(self, message): - ''' + async def send_message(self, message: str | bytes) -> None: + """ Send a WebSocket message. :param message: The message to send. :type message: str or bytes :raises ConnectionClosed: if connection is closed, or being closed - ''' + """ if self._close_reason: raise ConnectionClosed(self._close_reason) + event: TextMessage | BytesMessage if isinstance(message, str): event = TextMessage(data=message) elif isinstance(message, bytes): event = BytesMessage(data=message) else: - raise ValueError('message must be str or bytes') + raise ValueError("message must be str or bytes") await self._send(event) - def __str__(self): - ''' Connection ID and type. ''' - type_ = 'client' if self.is_client else 'server' - return f'{type_}-{self._id}' - - async def _accept(self, request, subprotocol, extra_headers): - ''' + def __str__(self) -> str: + """Connection ID and type.""" + type_ = "client" if self.is_client else "server" + return f"{type_}-{self._id}" + + async def _accept( + self, + request: Request, + subprotocol: str | None, + extra_headers: list[tuple[bytes, bytes]], + ) -> None: + """ Accept the handshake. This method is only applicable to server-side connections. @@ -1098,15 +1328,21 @@ async def _accept(self, request, subprotocol, extra_headers): :type subprotocol: str or None :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. - ''' + """ self._subprotocol = subprotocol self._path = request.target - await self._send(AcceptConnection(subprotocol=self._subprotocol, - extra_headers=extra_headers)) + await self._send( + AcceptConnection(subprotocol=self._subprotocol, extra_headers=extra_headers) + ) self._open_handshake.set() - async def _reject(self, status_code, headers, body): - ''' + async def _reject( + self, + status_code: int, + headers: list[tuple[bytes, bytes]], + body: bytes, + ) -> None: + """ Reject the handshake. :param int status_code: The 3 digit HTTP status code. In order to be @@ -1115,25 +1351,26 @@ async def _reject(self, status_code, headers, body): :param list[tuple[bytes,bytes]] headers: A list of 2-tuples containing key/value pairs to send as HTTP headers. :param bytes body: An optional response body. - ''' + """ if body: - headers.append(('Content-length', str(len(body)).encode('ascii'))) - reject_conn = RejectConnection(status_code=status_code, headers=headers, - has_body=bool(body)) + headers.append((b"Content-length", str(len(body)).encode("ascii"))) + reject_conn = RejectConnection( + status_code=status_code, headers=headers, has_body=bool(body) + ) await self._send(reject_conn) if body: reject_body = RejectData(data=body) await self._send(reject_body) - self._close_reason = CloseReason(1006, 'Rejected WebSocket handshake') + self._close_reason = CloseReason(1006, "Rejected WebSocket handshake") self._close_handshake.set() - async def _abort_web_socket(self): - ''' + async def _abort_web_socket(self) -> None: + """ If a stream is closed outside of this class, e.g. due to network conditions or because some other code closed our stream object, then we cannot perform the close handshake. We just need to clean up internal state. - ''' + """ close_reason = wsframeproto.CloseReason.ABNORMAL_CLOSURE if self._wsproto.state == ConnectionState.OPEN: self._wsproto.send(CloseConnection(code=close_reason.value)) @@ -1144,8 +1381,8 @@ async def _abort_web_socket(self): # (e.g. self.aclose()) to resume. self._close_handshake.set() - async def _close_stream(self): - ''' Close the TCP connection. ''' + async def _close_stream(self) -> None: + """Close the TCP connection.""" self._reader_running = False try: with _preserve_current_exception(): @@ -1154,86 +1391,97 @@ async def _close_stream(self): # This means the TCP connection is already dead. pass - async def _close_web_socket(self, code, reason=None): - ''' + async def _close_web_socket(self, code: int, reason: str | None = None) -> None: + """ Mark the WebSocket as closed. Close the message channel so that if any tasks are suspended in get_message(), they will wake up with a ConnectionClosed exception. - ''' + """ self._close_reason = CloseReason(code, reason) exc = ConnectionClosed(self._close_reason) - logger.debug('%s websocket closed %r', self, exc) + logger.debug("%s websocket closed %r", self, exc) await self._send_channel.aclose() - async def _get_request(self): - ''' + async def _get_request(self) -> WebSocketRequest: + """ Return a proposal for a WebSocket handshake. This method can only be called on server connections and it may only be called one time. :rtype: WebSocketRequest - ''' + """ if not self.is_server: - raise RuntimeError('This method is only valid for server connections.') + raise RuntimeError("This method is only valid for server connections.") if self._connection_proposal is None: - raise RuntimeError('No proposal available. Did you call this method' - ' multiple times or at the wrong time?') + raise RuntimeError( + "No proposal available. Did you call this method" + " multiple times or at the wrong time?" + ) proposal = await self._connection_proposal.wait_value() self._connection_proposal = None return proposal - async def _handle_request_event(self, event): - ''' + async def _handle_request_event(self, event: wsproto.events.Request) -> None: + """ Handle a connection request. This method is async even though it never awaits, because the event dispatch requires an async function. :param event: - ''' + """ proposal = WebSocketRequest(self, event) + assert self._connection_proposal is not None self._connection_proposal.set_value(proposal) - async def _handle_accept_connection_event(self, event): - ''' + async def _handle_accept_connection_event( + self, event: wsproto.events.AcceptConnection + ) -> None: + """ Handle an AcceptConnection event. :param wsproto.eventsAcceptConnection event: - ''' + """ self._subprotocol = event.subprotocol self._handshake_headers = tuple(event.extra_headers) self._open_handshake.set() - async def _handle_reject_connection_event(self, event): - ''' + async def _handle_reject_connection_event( + self, event: wsproto.events.RejectConnection + ) -> None: + """ Handle a RejectConnection event. :param event: - ''' + """ self._reject_status = event.status_code self._reject_headers = tuple(event.headers) if not event.has_body: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=None) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=None + ) - async def _handle_reject_data_event(self, event): - ''' + async def _handle_reject_data_event(self, event: wsproto.events.RejectData) -> None: + """ Handle a RejectData event. :param event: - ''' + """ self._reject_body += event.data if event.body_finished: - raise ConnectionRejected(self._reject_status, self._reject_headers, - body=self._reject_body) + raise ConnectionRejected( + self._reject_status, self._reject_headers, body=self._reject_body + ) - async def _handle_close_connection_event(self, event): - ''' + async def _handle_close_connection_event( + self, event: wsproto.events.CloseConnection + ) -> None: + """ Handle a close event. :param wsproto.events.CloseConnection event: - ''' + """ if self._wsproto.state == ConnectionState.REMOTE_CLOSING: # Set _close_reason in advance, so that send_message() will raise # ConnectionClosed during the close handshake. @@ -1249,17 +1497,20 @@ async def _handle_close_connection_event(self, event): if self.is_server: await self._close_stream() - async def _handle_message_event(self, event): - ''' + async def _handle_message_event( + self, + event: wsproto.events.BytesMessage | wsproto.events.TextMessage, + ) -> None: + """ Handle a message event. :param event: :type event: wsproto.events.BytesMessage or wsproto.events.TextMessage - ''' + """ self._message_size += len(event.data) self._message_parts.append(event.data) if self._message_size > self._max_message_size: - err = f'Exceeded maximum message size: {self._max_message_size} bytes' + err = f"Exceeded maximum message size: {self._max_message_size} bytes" self._message_size = 0 self._message_parts = [] self._close_reason = CloseReason(1009, err) @@ -1267,8 +1518,12 @@ async def _handle_message_event(self, event): await self._recv_channel.aclose() self._reader_running = False elif event.message_finished: - msg = (b'' if isinstance(event, BytesMessage) else '') \ - .join(self._message_parts) + msg: str | bytes + # Type checker does not understand `_message_parts` + if isinstance(event, BytesMessage): + msg = b"".join(cast("list[bytes]", self._message_parts)) + else: + msg = "".join(cast("list[str]", self._message_parts)) self._message_size = 0 self._message_parts = [] try: @@ -1279,20 +1534,20 @@ async def _handle_message_event(self, event): # and there's no useful cleanup that we can do here. pass - async def _handle_ping_event(self, event): - ''' + async def _handle_ping_event(self, event: wsproto.events.Ping) -> None: + """ Handle a PingReceived event. Wsproto queues a pong frame automatically, so this handler just needs to send it. :param wsproto.events.Ping event: - ''' - logger.debug('%s ping %r', self, event.payload) + """ + logger.debug("%s ping %r", self, event.payload) await self._send(event.response()) - async def _handle_pong_event(self, event): - ''' + async def _handle_pong_event(self, event: wsproto.events.Pong) -> None: + """ Handle a PongReceived event. When a pong is received, check if we have any ping requests waiting for @@ -1304,24 +1559,24 @@ async def _handle_pong_event(self, event): complicated if some handlers were sync. :param event: - ''' + """ payload = bytes(event.payload) try: - event = self._pings[payload] + ping_event = self._pings[payload] except KeyError: # We received a pong that doesn't match any in-flight pongs. Nothing # we can do with it, so ignore it. return while self._pings: - key, event = self._pings.popitem(0) - skipped = ' [skipped] ' if payload != key else ' ' - logger.debug('%s pong%s%r', self, skipped, key) - event.set() + key, ping_event = self._pings.popitem(False) + skipped = " [skipped] " if payload != key else " " + logger.debug("%s pong%s%r", self, skipped, key) + ping_event.set() if payload == key: break - async def _reader_task(self): - ''' A background task that reads network data and generates events. ''' + async def _reader_task(self) -> None: + """A background task that reads network data and generates events.""" handlers = { AcceptConnection: self._handle_accept_connection_event, BytesMessage: self._handle_message_event, @@ -1348,12 +1603,15 @@ async def _reader_task(self): event_type = type(event) try: handler = handlers[event_type] - logger.debug('%s received event: %s', self, - event_type) - await handler(event) + logger.debug("%s received event: %s", self, event_type) + # Type checkers don't understand looking up type in handlers. + # If we wanted to fix, best I can figure is we'd need a huge + # if-else or case block for every type individually. + await handler(event) # type: ignore[operator] except KeyError: - logger.warning('%s received unknown event type: "%s"', self, - event_type) + logger.warning( + '%s received unknown event type: "%s"', self, event_type + ) except ConnectionClosed: self._reader_running = False break @@ -1365,27 +1623,26 @@ async def _reader_task(self): await self._abort_web_socket() break if len(data) == 0: - logger.debug('%s received zero bytes (connection closed)', - self) + logger.debug("%s received zero bytes (connection closed)", self) # If TCP closed before WebSocket, then record it as an abnormal # closure. if self._wsproto.state != ConnectionState.CLOSED: await self._abort_web_socket() break - logger.debug('%s received %d bytes', self, len(data)) + logger.debug("%s received %d bytes", self, len(data)) if self._wsproto.state != ConnectionState.CLOSED: try: self._wsproto.receive_data(data) except wsproto.utilities.RemoteProtocolError as err: - logger.debug('%s remote protocol error: %s', self, err) + logger.debug("%s remote protocol error: %s", self, err) if err.event_hint: await self._send(err.event_hint) await self._close_stream() - logger.debug('%s reader task finished', self) + logger.debug("%s reader task finished", self) - async def _send(self, event): - ''' + async def _send(self, event: wsproto.events.Event) -> None: + """ Send an event to the remote WebSocket. The reader task and one or more writers might try to send messages at @@ -1393,20 +1650,22 @@ async def _send(self, event): requests to send data. :param wsproto.events.Event event: - ''' + """ data = self._wsproto.send(event) async with self._stream_lock: - logger.debug('%s sending %d bytes', self, len(data)) + logger.debug("%s sending %d bytes", self, len(data)) try: await self._stream.send_all(data) except (trio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() + assert self._close_reason is not None raise ConnectionClosed(self._close_reason) from None class Endpoint: - ''' Represents a connection endpoint. ''' - def __init__(self, address, port, is_ssl): + """Represents a connection endpoint.""" + + def __init__(self, address: str | int, port: int, is_ssl: bool) -> None: #: IP address :class:`ipaddress.ip_address` self.address = ip_address(address) #: TCP port @@ -1415,38 +1674,44 @@ def __init__(self, address, port, is_ssl): self.is_ssl = is_ssl @property - def url(self): - ''' Return a URL representation of a TCP endpoint, e.g. - ``ws://127.0.0.1:80``. ''' - scheme = 'wss' if self.is_ssl else 'ws' - if (self.port == 80 and not self.is_ssl) or \ - (self.port == 443 and self.is_ssl): - port_str = '' + def url(self) -> str: + """Return a URL representation of a TCP endpoint, e.g. + ``ws://127.0.0.1:80``.""" + scheme = "wss" if self.is_ssl else "ws" + if (self.port == 80 and not self.is_ssl) or (self.port == 443 and self.is_ssl): + port_str = "" else: - port_str = ':' + str(self.port) + port_str = ":" + str(self.port) if self.address.version == 4: - return f'{scheme}://{self.address}{port_str}' - return f'{scheme}://[{self.address}]{port_str}' + return f"{scheme}://{self.address}{port_str}" + return f"{scheme}://[{self.address}]{port_str}" - def __repr__(self): - ''' Return endpoint info as string. ''' + def __repr__(self) -> str: + """Return endpoint info as string.""" return f'Endpoint(address="{self.address}", port={self.port}, is_ssl={self.is_ssl})' class WebSocketServer: - ''' + """ WebSocket server. The server class handles incoming connections on one or more ``Listener`` objects. For each incoming connection, it creates a ``WebSocketConnection`` instance and starts some background tasks, - ''' + """ - def __init__(self, handler, listeners, *, handler_nursery=None, - message_queue_size=MESSAGE_QUEUE_SIZE, - max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT): - ''' + def __init__( + self, + handler: Callable[[WebSocketRequest], Awaitable[None]], + listeners: Sequence[trio.SSLListener[trio.SocketStream] | trio.SocketListener], + *, + handler_nursery: trio.Nursery | None = None, + message_queue_size: int = MESSAGE_QUEUE_SIZE, + max_message_size: int = MAX_MESSAGE_SIZE, + connect_timeout: float = CONN_TIMEOUT, + disconnect_timeout: float = CONN_TIMEOUT, + ) -> None: + """ Constructor. Note that if ``host`` is ``None`` and ``port`` is zero, then you may get @@ -1465,9 +1730,9 @@ def __init__(self, handler, listeners, *, handler_nursery=None, to finish connection handshake before timing out. :param float disconnect_timeout: The number of seconds to wait for a client to finish the closing handshake before timing out. - ''' + """ if len(listeners) == 0: - raise ValueError('Listeners must contain at least one item.') + raise ValueError("Listeners must contain at least one item.") self._handler = handler self._handler_nursery = handler_nursery self._listeners = listeners @@ -1477,7 +1742,7 @@ def __init__(self, handler, listeners, *, handler_nursery=None, self._disconnect_timeout = disconnect_timeout @property - def port(self): + def port(self) -> int: """Returns the requested or kernel-assigned port number. In the case of kernel-assigned port (requested with port=0 in the @@ -1489,31 +1754,36 @@ def port(self): listener must be socket-based. """ if len(self._listeners) > 1: - raise RuntimeError('Cannot get port because this server has' - ' more than 1 listeners.') + raise RuntimeError( + "Cannot get port because this server has more than 1 listener." + ) listener = self.listeners[0] try: - return listener.port + return listener.port # type: ignore[union-attr] except AttributeError: - raise RuntimeError(f'This socket does not have a port: {repr(listener)}') from None + raise RuntimeError( + f"This socket does not have a port: {repr(listener)}" + ) from None @property - def listeners(self): - ''' + def listeners(self) -> list[Endpoint | str]: + """ Return a list of listener metadata. Each TCP listener is represented as an ``Endpoint`` instance. Other listener types are represented by their ``repr()``. :returns: Listeners :rtype list[Endpoint or str]: - ''' - listeners = [] + """ + listeners: list[Endpoint | str] = [] for listener in self._listeners: socket, is_ssl = None, False if isinstance(listener, trio.SocketListener): socket = listener.socket elif isinstance(listener, trio.SSLListener): - socket = listener.transport_listener.socket + internal_listener = listener.transport_listener + assert isinstance(internal_listener, trio.SocketListener) + socket = internal_listener.socket is_ssl = True if socket: sockname = socket.getsockname() @@ -1522,8 +1792,16 @@ def listeners(self): listeners.append(repr(listener)) return listeners - async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): - ''' + # Type ignore is because type checker does not think NoReturn is + # real for Trio 0.25.1 (current version used in requirements file as + # of writing). Not a problem for newer versions however, which is + # why we have unused-ignore as well. + async def run( # type: ignore[misc,unused-ignore] + self, + *, + task_status: trio.TaskStatus[WebSocketServer] = trio.TASK_STATUS_IGNORED, + ) -> NoReturn: + """ Start serving incoming connections requests. This method supports the Trio nursery start protocol: ``server = await @@ -1532,30 +1810,34 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): :param task_status: Part of the Trio nursery start protocol. :returns: This method never returns unless cancelled. - ''' + """ async with trio.open_nursery() as nursery: - serve_listeners = partial(trio.serve_listeners, - self._handle_connection, self._listeners, - handler_nursery=self._handler_nursery) + serve_listeners = partial( + trio.serve_listeners, + self._handle_connection, + list(self._listeners), + handler_nursery=self._handler_nursery, + ) await nursery.start(serve_listeners) - logger.debug('Listening on %s', - ','.join([str(l) for l in self.listeners])) + logger.debug("Listening on %s", ",".join([str(l) for l in self.listeners])) task_status.started(self) await trio.sleep_forever() - async def _handle_connection(self, stream): - ''' + async def _handle_connection(self, stream: trio.abc.Stream) -> None: + """ Handle an incoming connection by spawning a connection background task and a handler task inside a new nursery. :param stream: :type stream: trio.abc.Stream - ''' + """ async with trio.open_nursery() as nursery: - connection = WebSocketConnection(stream, + connection = WebSocketConnection( + stream, WSConnection(ConnectionType.SERVER), message_queue_size=self._message_queue_size, - max_message_size=self._max_message_size) + max_message_size=self._max_message_size, + ) nursery.start_soon(connection._reader_task) with trio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() diff --git a/trio_websocket/_version.py b/trio_websocket/_version.py index 2320701..5c47800 100644 --- a/trio_websocket/_version.py +++ b/trio_websocket/_version.py @@ -1 +1 @@ -__version__ = '0.12.0-dev' +__version__ = "0.12.0-dev" diff --git a/trio_websocket/py.typed b/trio_websocket/py.typed new file mode 100644 index 0000000..e69de29