|
45 | 45 | from .reasoncodes import ReasonCode, ReasonCodes
|
46 | 46 | from .subscribeoptions import SubscribeOptions
|
47 | 47 |
|
| 48 | +try: |
| 49 | + from OpenSSL import SSL |
| 50 | + from OpenSSL.crypto import X509 |
| 51 | + |
| 52 | + def _subject_alt_name_string(cert: X509) -> list: |
| 53 | + """Extracts the subject alternative name (SAN) entries from the certificate.""" |
| 54 | + san = [] |
| 55 | + for i in range(cert.get_extension_count()): |
| 56 | + ext = cert.get_extension(i) |
| 57 | + if ext.get_short_name() == b'subjectAltName': |
| 58 | + san_entries = ext.__str__().split(', ') |
| 59 | + for entry in san_entries: |
| 60 | + key, value = entry.split(':', 1) |
| 61 | + print(f"key {key}: value {value}") |
| 62 | + san.append((key.strip(), value.strip())) |
| 63 | + return san |
| 64 | + |
| 65 | + def _openssl_match_hostname(cert: X509, hostname: str): |
| 66 | + """Verify that *cert* matches the *hostname* according to RFC 2818 and RFC 6125 rules. |
| 67 | + CertificateError is raised on failure. On success, the function returns nothing. |
| 68 | + """ |
| 69 | + if not cert: |
| 70 | + raise ValueError("Empty or no certificate. match_hostname needs a certificate.") |
| 71 | + |
| 72 | + dnsnames = [] |
| 73 | + # Extract subject alternative name (SAN) entries |
| 74 | + san = _subject_alt_name_string(cert) |
| 75 | + for key, value in san: |
| 76 | + if key == 'DNS': |
| 77 | + if ssl._dnsname_match(value, hostname): |
| 78 | + return |
| 79 | + dnsnames.append(value) |
| 80 | + |
| 81 | + if not dnsnames: |
| 82 | + # TODO: check if no dns entry to use subject |
| 83 | + raise ValueError("pyOpenssl match_hostname: using subject is not supported.") |
| 84 | + |
| 85 | + if len(dnsnames) > 1: |
| 86 | + raise ssl.CertificateError(f"Hostname {hostname} doesn't match any of {', '.join(map(repr, dnsnames))}") |
| 87 | + elif len(dnsnames) == 1: |
| 88 | + raise ssl.CertificateError(f"Hostname {hostname} doesn't match {dnsnames[0]}") |
| 89 | + else: |
| 90 | + raise ssl.CertificateError("No appropriate commonName or subjectAltName fields were found") |
| 91 | + |
| 92 | + HAS_OPENSSL = True |
| 93 | +except ImportError: |
| 94 | + HAS_OPENSSL = False |
| 95 | + |
48 | 96 | try:
|
49 | 97 | from typing import Literal
|
50 | 98 | except ImportError:
|
@@ -851,7 +899,7 @@ def __init__(
|
851 | 899 | self._thread: threading.Thread | None = None
|
852 | 900 | self._thread_terminate = False
|
853 | 901 | self._ssl = False
|
854 |
| - self._ssl_context: ssl.SSLContext | None = None |
| 902 | + self._ssl_context: ssl.SSLContext | SSL.Context | None = None |
855 | 903 | # Only used when SSL context does not have check_hostname attribute
|
856 | 904 | self._tls_insecure = False
|
857 | 905 | self._logger: logging.Logger | None = None
|
@@ -1181,26 +1229,37 @@ def ws_set_options(
|
1181 | 1229 |
|
1182 | 1230 | def tls_set_context(
|
1183 | 1231 | self,
|
1184 |
| - context: ssl.SSLContext | None = None, |
| 1232 | + context: ssl.SSLContext | SSL.Context | None = None, |
1185 | 1233 | ) -> None:
|
1186 | 1234 | """Configure network encryption and authentication context. Enables SSL/TLS support.
|
1187 | 1235 |
|
1188 |
| - :param context: an ssl.SSLContext object. By default this is given by |
1189 |
| - ``ssl.create_default_context()``, if available. |
| 1236 | + :param context: an ssl.SSLContext or OpenSSL.SSL.Context object. By default, this is given by |
| 1237 | + ``ssl.create_default_context()`` if available. |
1190 | 1238 |
|
1191 |
| - Must be called before `connect()`, `connect_async()` or `connect_srv()`.""" |
| 1239 | + Must be called before `connect()`, `connect_async()` or `connect_srv()`. |
| 1240 | + """ |
1192 | 1241 | if self._ssl_context is not None:
|
1193 | 1242 | raise ValueError('SSL/TLS has already been configured.')
|
1194 | 1243 |
|
1195 | 1244 | if context is None:
|
1196 |
| - context = ssl.create_default_context() |
| 1245 | + if HAS_OPENSSL: |
| 1246 | + raise ValueError("OpenSSL custom context is not provided.") |
| 1247 | + else: |
| 1248 | + context = ssl.create_default_context() |
1197 | 1249 |
|
1198 | 1250 | self._ssl = True
|
1199 | 1251 | self._ssl_context = context
|
1200 | 1252 |
|
1201 |
| - # Ensure _tls_insecure is consistent with check_hostname attribute |
1202 |
| - if hasattr(context, 'check_hostname'): |
| 1253 | + # Ensure _tls_insecure is consistent with check_hostname attribute for ssl.SSLContext |
| 1254 | + if isinstance(context, ssl.SSLContext) and hasattr(context, 'check_hostname'): |
1203 | 1255 | self._tls_insecure = not context.check_hostname
|
| 1256 | + elif HAS_OPENSSL and isinstance(context, SSL.Context): |
| 1257 | + # PyOpenSSL Context does not have check_hostname attribute |
| 1258 | + # Set _tls_insecure based on custom logic if necessary |
| 1259 | + self._tls_insecure = False # Assuming default to False for PyOpenSSL |
| 1260 | + else: |
| 1261 | + # If OpenSSL is not available and context is an SSL.Context, raise an error |
| 1262 | + raise ValueError("OpenSSL is not available, cannot use SSL.Context.") |
1204 | 1263 |
|
1205 | 1264 | def tls_set(
|
1206 | 1265 | self,
|
@@ -4638,43 +4697,67 @@ def _create_socket_connection(self) -> _socket.socket:
|
4638 | 4697 | return socks.create_connection(addr, timeout=self._connect_timeout, source_address=source, **proxy)
|
4639 | 4698 | else:
|
4640 | 4699 | return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source)
|
4641 |
| - |
4642 |
| - def _ssl_wrap_socket(self, tcp_sock: _socket.socket) -> ssl.SSLSocket: |
| 4700 | + |
| 4701 | + def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket: |
4643 | 4702 | if self._ssl_context is None:
|
4644 | 4703 | raise ValueError(
|
4645 | 4704 | "Impossible condition. _ssl_context should never be None if _ssl is True"
|
4646 | 4705 | )
|
4647 |
| - |
| 4706 | + |
4648 | 4707 | verify_host = not self._tls_insecure
|
4649 | 4708 | try:
|
4650 |
| - # Try with server_hostname, even it's not supported in certain scenarios |
4651 |
| - ssl_sock = self._ssl_context.wrap_socket( |
4652 |
| - tcp_sock, |
4653 |
| - server_hostname=self._host, |
4654 |
| - do_handshake_on_connect=False, |
4655 |
| - ) |
| 4709 | + if isinstance(self._ssl_context, ssl.SSLContext): |
| 4710 | + # Use the built-in ssl.SSLContext |
| 4711 | + ssl_sock = self._ssl_context.wrap_socket( |
| 4712 | + tcp_sock, |
| 4713 | + server_hostname=self._host, |
| 4714 | + do_handshake_on_connect=False, |
| 4715 | + ) |
| 4716 | + elif HAS_OPENSSL and isinstance(self._ssl_context, SSL.Context): |
| 4717 | + # Use PyOpenSSL's SSL.Context |
| 4718 | + conn = SSL.Connection(self._ssl_context, tcp_sock) |
| 4719 | + conn.set_connect_state() |
| 4720 | + if self._host: |
| 4721 | + conn.set_tlsext_host_name(self._host.encode('utf-8')) |
| 4722 | + ssl_sock = conn |
| 4723 | + else: |
| 4724 | + raise ValueError("Unsupported SSL context type") |
4656 | 4725 | except ssl.CertificateError:
|
4657 |
| - # CertificateError is derived from ValueError |
4658 | 4726 | raise
|
4659 | 4727 | except ValueError:
|
4660 |
| - # Python version requires SNI in order to handle server_hostname, but SNI is not available |
4661 |
| - ssl_sock = self._ssl_context.wrap_socket( |
4662 |
| - tcp_sock, |
4663 |
| - do_handshake_on_connect=False, |
4664 |
| - ) |
4665 |
| - else: |
4666 |
| - # If SSL context has already checked hostname, then don't need to do it again |
4667 |
| - if getattr(self._ssl_context, 'check_hostname', False): # type: ignore |
4668 |
| - verify_host = False |
| 4728 | + if isinstance(self._ssl_context, ssl.SSLContext): |
| 4729 | + ssl_sock = self._ssl_context.wrap_socket( |
| 4730 | + tcp_sock, |
| 4731 | + do_handshake_on_connect=False, |
| 4732 | + ) |
| 4733 | + else: |
| 4734 | + raise |
4669 | 4735 |
|
4670 | 4736 | ssl_sock.settimeout(self._keepalive)
|
4671 |
| - ssl_sock.do_handshake() |
4672 | 4737 |
|
4673 |
| - if verify_host: |
4674 |
| - # TODO: this type error is a true error: |
4675 |
| - # error: Module has no attribute "match_hostname" [attr-defined] |
4676 |
| - # Python 3.12 no longer have this method. |
4677 |
| - ssl.match_hostname(ssl_sock.getpeercert(), self._host) # type: ignore |
| 4738 | + # Function to handle retries for non-blocking SSL handshake |
| 4739 | + def do_handshake_with_retries(ssl_sock, retries=35, delay=0.1): |
| 4740 | + for attempt in range(retries): |
| 4741 | + try: |
| 4742 | + ssl_sock.do_handshake() |
| 4743 | + return |
| 4744 | + except SSL.WantReadError: |
| 4745 | + if attempt == retries - 1: |
| 4746 | + raise RuntimeError("Handshake failed after maximum retries") |
| 4747 | + time.sleep(delay) |
| 4748 | + |
| 4749 | + if HAS_OPENSSL and isinstance(ssl_sock, SSL.Connection): |
| 4750 | + do_handshake_with_retries(ssl_sock) |
| 4751 | + if verify_host: |
| 4752 | + if getattr(self._ssl_context, 'check_hostname', False): |
| 4753 | + verify_host = False |
| 4754 | + _openssl_match_hostname(ssl_sock.get_peer_certificate(), self._host) |
| 4755 | + else: |
| 4756 | + ssl_sock.do_handshake() |
| 4757 | + if verify_host: |
| 4758 | + if getattr(self._ssl_context, 'check_hostname', False): |
| 4759 | + verify_host = False |
| 4760 | + ssl.match_hostname(ssl_sock.getpeercert(), self._host) |
4678 | 4761 |
|
4679 | 4762 | return ssl_sock
|
4680 | 4763 |
|
|
0 commit comments