From 1b2d5e46e50c49ebc7486abb920a4936e4b3c05a Mon Sep 17 00:00:00 2001 From: Walter BONETTI Date: Thu, 6 Jun 2024 10:25:36 -0400 Subject: [PATCH 1/5] Add pyOpenSSL context support Signed-off-by: Walter BONETTI --- pyproject.toml | 3 + src/paho/mqtt/client.py | 149 +++++++++++++++++++++++++++++++--------- 2 files changed, 119 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e990812d..f2c3a944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,9 @@ dependencies = [] proxy = [ "PySocks", ] +openssl = [ + "pyOpenSSL" +] [project.urls] Homepage = "http://eclipse.org/paho" diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 4ccc8696..7456b44f 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -45,6 +45,54 @@ from .reasoncodes import ReasonCode, ReasonCodes from .subscribeoptions import SubscribeOptions +try: + from OpenSSL import SSL + from OpenSSL.crypto import X509 + + def _subject_alt_name_string(cert: X509) -> list: + """Extracts the subject alternative name (SAN) entries from the certificate.""" + san = [] + for i in range(cert.get_extension_count()): + ext = cert.get_extension(i) + if ext.get_short_name() == b'subjectAltName': + san_entries = ext.__str__().split(', ') + for entry in san_entries: + key, value = entry.split(':', 1) + print(f"key {key}: value {value}") + san.append((key.strip(), value.strip())) + return san + + def _openssl_match_hostname(cert: X509, hostname: str): + """Verify that *cert* matches the *hostname* according to RFC 2818 and RFC 6125 rules. + CertificateError is raised on failure. On success, the function returns nothing. + """ + if not cert: + raise ValueError("Empty or no certificate. match_hostname needs a certificate.") + + dnsnames = [] + # Extract subject alternative name (SAN) entries + san = _subject_alt_name_string(cert) + for key, value in san: + if key == 'DNS': + if ssl._dnsname_match(value, hostname): + return + dnsnames.append(value) + + if not dnsnames: + # TODO: check if no dns entry to use subject + raise ValueError("pyOpenssl match_hostname: using subject is not supported.") + + if len(dnsnames) > 1: + raise ssl.CertificateError(f"Hostname {hostname} doesn't match any of {', '.join(map(repr, dnsnames))}") + elif len(dnsnames) == 1: + raise ssl.CertificateError(f"Hostname {hostname} doesn't match {dnsnames[0]}") + else: + raise ssl.CertificateError("No appropriate commonName or subjectAltName fields were found") + + HAS_OPENSSL = True +except ImportError: + HAS_OPENSSL = False + try: from typing import Literal except ImportError: @@ -851,7 +899,7 @@ def __init__( self._thread: threading.Thread | None = None self._thread_terminate = False self._ssl = False - self._ssl_context: ssl.SSLContext | None = None + self._ssl_context: ssl.SSLContext | SSL.Context | None = None # Only used when SSL context does not have check_hostname attribute self._tls_insecure = False self._logger: logging.Logger | None = None @@ -1181,26 +1229,37 @@ def ws_set_options( def tls_set_context( self, - context: ssl.SSLContext | None = None, + context: ssl.SSLContext | SSL.Context | None = None, ) -> None: """Configure network encryption and authentication context. Enables SSL/TLS support. - :param context: an ssl.SSLContext object. By default this is given by - ``ssl.create_default_context()``, if available. + :param context: an ssl.SSLContext or OpenSSL.SSL.Context object. By default, this is given by + ``ssl.create_default_context()`` if available. - Must be called before `connect()`, `connect_async()` or `connect_srv()`.""" + Must be called before `connect()`, `connect_async()` or `connect_srv()`. + """ if self._ssl_context is not None: raise ValueError('SSL/TLS has already been configured.') if context is None: - context = ssl.create_default_context() + if HAS_OPENSSL: + raise ValueError("OpenSSL custom context is not provided.") + else: + context = ssl.create_default_context() self._ssl = True self._ssl_context = context - # Ensure _tls_insecure is consistent with check_hostname attribute - if hasattr(context, 'check_hostname'): + # Ensure _tls_insecure is consistent with check_hostname attribute for ssl.SSLContext + if isinstance(context, ssl.SSLContext) and hasattr(context, 'check_hostname'): self._tls_insecure = not context.check_hostname + elif HAS_OPENSSL and isinstance(context, SSL.Context): + # PyOpenSSL Context does not have check_hostname attribute + # Set _tls_insecure based on custom logic if necessary + self._tls_insecure = False # Assuming default to False for PyOpenSSL + else: + # If OpenSSL is not available and context is an SSL.Context, raise an error + raise ValueError("OpenSSL is not available, cannot use SSL.Context.") def tls_set( self, @@ -4638,43 +4697,67 @@ def _create_socket_connection(self) -> _socket.socket: return socks.create_connection(addr, timeout=self._connect_timeout, source_address=source, **proxy) else: return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) - - def _ssl_wrap_socket(self, tcp_sock: _socket.socket) -> ssl.SSLSocket: + + def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket: if self._ssl_context is None: raise ValueError( "Impossible condition. _ssl_context should never be None if _ssl is True" ) - + verify_host = not self._tls_insecure try: - # Try with server_hostname, even it's not supported in certain scenarios - ssl_sock = self._ssl_context.wrap_socket( - tcp_sock, - server_hostname=self._host, - do_handshake_on_connect=False, - ) + if isinstance(self._ssl_context, ssl.SSLContext): + # Use the built-in ssl.SSLContext + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, + server_hostname=self._host, + do_handshake_on_connect=False, + ) + elif HAS_OPENSSL and isinstance(self._ssl_context, SSL.Context): + # Use PyOpenSSL's SSL.Context + conn = SSL.Connection(self._ssl_context, tcp_sock) + conn.set_connect_state() + if self._host: + conn.set_tlsext_host_name(self._host.encode('utf-8')) + ssl_sock = conn + else: + raise ValueError("Unsupported SSL context type") except ssl.CertificateError: - # CertificateError is derived from ValueError raise except ValueError: - # Python version requires SNI in order to handle server_hostname, but SNI is not available - ssl_sock = self._ssl_context.wrap_socket( - tcp_sock, - do_handshake_on_connect=False, - ) - else: - # If SSL context has already checked hostname, then don't need to do it again - if getattr(self._ssl_context, 'check_hostname', False): # type: ignore - verify_host = False + if isinstance(self._ssl_context, ssl.SSLContext): + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, + do_handshake_on_connect=False, + ) + else: + raise ssl_sock.settimeout(self._keepalive) - ssl_sock.do_handshake() - if verify_host: - # TODO: this type error is a true error: - # error: Module has no attribute "match_hostname" [attr-defined] - # Python 3.12 no longer have this method. - ssl.match_hostname(ssl_sock.getpeercert(), self._host) # type: ignore + # Function to handle retries for non-blocking SSL handshake + def do_handshake_with_retries(ssl_sock, retries=35, delay=0.1): + for attempt in range(retries): + try: + ssl_sock.do_handshake() + return + except SSL.WantReadError: + if attempt == retries - 1: + raise RuntimeError("Handshake failed after maximum retries") + time.sleep(delay) + + if HAS_OPENSSL and isinstance(ssl_sock, SSL.Connection): + do_handshake_with_retries(ssl_sock) + if verify_host: + if getattr(self._ssl_context, 'check_hostname', False): + verify_host = False + _openssl_match_hostname(ssl_sock.get_peer_certificate(), self._host) + else: + ssl_sock.do_handshake() + if verify_host: + if getattr(self._ssl_context, 'check_hostname', False): + verify_host = False + ssl.match_hostname(ssl_sock.getpeercert(), self._host) return ssl_sock From 02b583c98ad25666200a8573157df4bf719d3168 Mon Sep 17 00:00:00 2001 From: Walter BONETTI Date: Fri, 5 Jul 2024 10:12:09 -0400 Subject: [PATCH 2/5] Remove debug print Signed-off-by: Walter BONETTI --- src/paho/mqtt/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 7456b44f..b4c69807 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -58,7 +58,6 @@ def _subject_alt_name_string(cert: X509) -> list: san_entries = ext.__str__().split(', ') for entry in san_entries: key, value = entry.split(':', 1) - print(f"key {key}: value {value}") san.append((key.strip(), value.strip())) return san From 6df0433cb42f76f1f1c05f773eb3ecca3c8892c4 Mon Sep 17 00:00:00 2001 From: Walter BONETTI Date: Mon, 29 Jul 2024 17:43:33 -0400 Subject: [PATCH 3/5] Fix: missing error-handling on PyOpenSSL errors Signed-off-by: Walter BONETTI --- src/paho/mqtt/client.py | 121 ++++++++++++++++++++++++++++++---------- 1 file changed, 91 insertions(+), 30 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index b4c69807..e6c24b02 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -1140,32 +1140,88 @@ def logger(self, value: logging.Logger | None) -> None: def _sock_recv(self, bufsize: int) -> bytes: if self._sock is None: raise ConnectionError("self._sock is None") - try: - return self._sock.recv(bufsize) - except ssl.SSLWantReadError as err: - raise BlockingIOError() from err - except ssl.SSLWantWriteError as err: - self._call_socket_register_write() - raise BlockingIOError() from err - except AttributeError as err: - self._easy_log( - MQTT_LOG_DEBUG, "socket was None: %s", err) - raise ConnectionError() from err + + if HAS_OPENSSL: + from OpenSSL import SSL + + if isinstance(self._ssl_context, SSL.Context): + try: + return self._sock.recv(bufsize) + except SSL.WantReadError as err: + raise BlockingIOError() from err + except SSL.WantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except SSL.ZeroReturnError as err: + raise BlockingIOError() from err + except AttributeError as err: + self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) + raise ConnectionError() from err + else: + try: + return self._sock.recv(bufsize) + except ssl.SSLWantReadError as err: + raise BlockingIOError() from err + except ssl.SSLWantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except AttributeError as err: + self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) + raise ConnectionError() from err + else: + try: + return self._sock.recv(bufsize) + except ssl.SSLWantReadError as err: + raise BlockingIOError() from err + except ssl.SSLWantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except AttributeError as err: + self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) + raise ConnectionError() from err def _sock_send(self, buf: bytes) -> int: if self._sock is None: raise ConnectionError("self._sock is None") - try: - return self._sock.send(buf) - except ssl.SSLWantReadError as err: - raise BlockingIOError() from err - except ssl.SSLWantWriteError as err: - self._call_socket_register_write() - raise BlockingIOError() from err - except BlockingIOError as err: - self._call_socket_register_write() - raise BlockingIOError() from err + if HAS_OPENSSL: + from OpenSSL import SSL + + if isinstance(self._ssl_context, SSL.Context): + try: + return self._sock.send(buf) + except SSL.WantReadError as err: + raise BlockingIOError() from err + except SSL.WantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except SSL.ZeroReturnError as err: + raise BlockingIOError() from err + except BlockingIOError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + else: + try: + return self._sock.send(buf) + except ssl.SSLWantReadError as err: + raise BlockingIOError() from err + except ssl.SSLWantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except BlockingIOError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + else: + try: + return self._sock.send(buf) + except ssl.SSLWantReadError as err: + raise BlockingIOError() from err + except ssl.SSLWantWriteError as err: + self._call_socket_register_write() + raise BlockingIOError() from err + except BlockingIOError as err: + self._call_socket_register_write() + raise BlockingIOError() from err def _sock_close(self) -> None: """Close the connection to the server.""" @@ -4696,13 +4752,13 @@ def _create_socket_connection(self) -> _socket.socket: return socks.create_connection(addr, timeout=self._connect_timeout, source_address=source, **proxy) else: return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) - + def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket: if self._ssl_context is None: raise ValueError( "Impossible condition. _ssl_context should never be None if _ssl is True" ) - + verify_host = not self._tls_insecure try: if isinstance(self._ssl_context, ssl.SSLContext): @@ -4712,13 +4768,18 @@ def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket: server_hostname=self._host, do_handshake_on_connect=False, ) - elif HAS_OPENSSL and isinstance(self._ssl_context, SSL.Context): - # Use PyOpenSSL's SSL.Context - conn = SSL.Connection(self._ssl_context, tcp_sock) - conn.set_connect_state() - if self._host: - conn.set_tlsext_host_name(self._host.encode('utf-8')) - ssl_sock = conn + elif HAS_OPENSSL: + from OpenSSL import SSL + + if isinstance(self._ssl_context, SSL.Context): + # Use PyOpenSSL's SSL.Context + conn = SSL.Connection(self._ssl_context, tcp_sock) + conn.set_connect_state() + if self._host: + conn.set_tlsext_host_name(self._host.encode('utf-8')) + ssl_sock = conn + else: + raise ValueError("Unsupported SSL context type") else: raise ValueError("Unsupported SSL context type") except ssl.CertificateError: From de609df3d203b857cfda28b6ad8fa5a2602cf08f Mon Sep 17 00:00:00 2001 From: Walter BONETTI Date: Thu, 1 Aug 2024 09:51:18 -0400 Subject: [PATCH 4/5] Fix: SSL.ZeroReturnError must raise ConnectionError SSL.ZeroReturnError is raised when the connection has been closed cleanly, and no more data will be received. This behavior indicates that the connection is no longer usable and thus can be considered a type of connection error. Therefore, it is appropriate to raise a ConnectionError when encountering SSL.ZeroReturnError. --- src/paho/mqtt/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index e6c24b02..976f42bc 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -1153,7 +1153,7 @@ def _sock_recv(self, bufsize: int) -> bytes: self._call_socket_register_write() raise BlockingIOError() from err except SSL.ZeroReturnError as err: - raise BlockingIOError() from err + raise ConnectionError() from err except AttributeError as err: self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) raise ConnectionError() from err @@ -1196,7 +1196,7 @@ def _sock_send(self, buf: bytes) -> int: self._call_socket_register_write() raise BlockingIOError() from err except SSL.ZeroReturnError as err: - raise BlockingIOError() from err + raise ConnectionError() from err except BlockingIOError as err: self._call_socket_register_write() raise BlockingIOError() from err From fd552303b9715c41a09156d390892125578086d1 Mon Sep 17 00:00:00 2001 From: Walter BONETTI Date: Wed, 21 Aug 2024 16:07:22 -0400 Subject: [PATCH 5/5] Fix: missing SysCallError and WantX509LookupError Signed-off-by: Walter BONETTI --- src/paho/mqtt/client.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 976f42bc..a38a37f3 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -1152,8 +1152,12 @@ def _sock_recv(self, bufsize: int) -> bytes: except SSL.WantWriteError as err: self._call_socket_register_write() raise BlockingIOError() from err + except SSL.WantX509LookupError as err: + raise ConnectionError() from err except SSL.ZeroReturnError as err: raise ConnectionError() from err + except SSL.SysCallError as err: + raise ConnectionError() from err except AttributeError as err: self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) raise ConnectionError() from err @@ -1195,8 +1199,12 @@ def _sock_send(self, buf: bytes) -> int: except SSL.WantWriteError as err: self._call_socket_register_write() raise BlockingIOError() from err + except SSL.WantX509LookupError as err: + raise ConnectionError() from err except SSL.ZeroReturnError as err: raise ConnectionError() from err + except SSL.SysCallError as err: + raise ConnectionError() from err except BlockingIOError as err: self._call_socket_register_write() raise BlockingIOError() from err