Skip to content

Commit 12fe022

Browse files
authored
Merge pull request #511 from nats-io/tls-handshake-first
Add tls_handshake_first option.
2 parents cf86a09 + 6127b18 commit 12fe022

File tree

5 files changed

+173
-20
lines changed

5 files changed

+173
-20
lines changed

nats/aio/client.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
)
6565
from .transport import TcpTransport, Transport, WebSocketTransport
6666

67-
__version__ = '2.4.0'
67+
__version__ = '2.5.0'
6868
__lang__ = 'python3'
6969
_logger = logging.getLogger(__name__)
7070
PROTOCOL = 1
@@ -305,6 +305,7 @@ async def connect(
305305
no_echo: bool = False,
306306
tls: Optional[ssl.SSLContext] = None,
307307
tls_hostname: Optional[str] = None,
308+
tls_handshake_first: bool = False,
308309
user: Optional[str] = None,
309310
password: Optional[str] = None,
310311
token: Optional[str] = None,
@@ -448,6 +449,7 @@ async def subscribe_handler(msg):
448449
self.options["token"] = token
449450
self.options["connect_timeout"] = connect_timeout
450451
self.options["drain_timeout"] = drain_timeout
452+
self.options['tls_handshake_first'] = tls_handshake_first
451453

452454
if tls:
453455
self.options['tls'] = tls
@@ -1886,6 +1888,24 @@ async def _process_connect_init(self) -> None:
18861888
assert self._current_server, "must be called only from Client.connect"
18871889
self._status = Client.CONNECTING
18881890

1891+
# Check whether to reuse the original hostname for an implicit route.
1892+
hostname = None
1893+
if "tls_hostname" in self.options:
1894+
hostname = self.options["tls_hostname"]
1895+
elif self._current_server.tls_name is not None:
1896+
hostname = self._current_server.tls_name
1897+
else:
1898+
hostname = self._current_server.uri.hostname
1899+
1900+
handshake_first = self.options['tls_handshake_first']
1901+
if handshake_first:
1902+
await self._transport.connect_tls(
1903+
hostname,
1904+
self.ssl_context,
1905+
DEFAULT_BUFFER_SIZE,
1906+
self.options['connect_timeout'],
1907+
)
1908+
18891909
connection_completed = self._transport.readline()
18901910
info_line = await asyncio.wait_for(
18911911
connection_completed, self.options["connect_timeout"]
@@ -1921,24 +1941,16 @@ async def _process_connect_init(self) -> None:
19211941

19221942
if 'tls_required' in self._server_info and self._server_info[
19231943
'tls_required'] and self._current_server.uri.scheme != "ws":
1924-
# Check whether to reuse the original hostname for an implicit route.
1925-
hostname = None
1926-
if "tls_hostname" in self.options:
1927-
hostname = self.options["tls_hostname"]
1928-
elif self._current_server.tls_name is not None:
1929-
hostname = self._current_server.tls_name
1930-
else:
1931-
hostname = self._current_server.uri.hostname
1932-
1933-
await self._transport.drain() # just in case something is left
1934-
1935-
# connect to transport via tls
1936-
await self._transport.connect_tls(
1937-
hostname,
1938-
self.ssl_context,
1939-
DEFAULT_BUFFER_SIZE,
1940-
self.options['connect_timeout'],
1941-
)
1944+
if not handshake_first:
1945+
await self._transport.drain() # just in case something is left
1946+
1947+
# connect to transport via tls
1948+
await self._transport.connect_tls(
1949+
hostname,
1950+
self.ssl_context,
1951+
DEFAULT_BUFFER_SIZE,
1952+
self.options['connect_timeout'],
1953+
)
19421954

19431955
# Refresh state of parser upon reconnect.
19441956
if self.is_reconnecting:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# These are here for GitHub's dependency graph and help with setuptools support in some environments.
55
setup(
66
name="nats-py",
7-
version='2.4.0',
7+
version='2.5.0',
88
license='Apache 2 License',
99
extras_require={
1010
'nkeys': ['nkeys'],

tests/conf/tls_handshake_first.conf

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
port: 4224
2+
tls {
3+
cert_file: "./tests/certs/server-cert.pem"
4+
key_file: "./tests/certs/server-key.pem"
5+
ca_file: "./tests/certs/ca.pem"
6+
handshake_first: true
7+
verify: true
8+
}

tests/test_client.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import http.client
33
import json
44
import ssl
5+
import os
56
import time
67
import unittest
78
import urllib
@@ -21,6 +22,7 @@
2122
MultiTLSServerAuthTestCase,
2223
SingleServerTestCase,
2324
TLSServerTestCase,
25+
TLSServerHandshakeFirstTestCase,
2426
NoAuthUserServerTestCase,
2527
async_test,
2628
)
@@ -1797,6 +1799,110 @@ async def worker_handler(msg):
17971799
self.assertEqual(1, err_count)
17981800

17991801

1802+
class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase):
1803+
1804+
@async_test
1805+
async def test_connect(self):
1806+
if os.environ.get('NATS_SERVER_VERSION') != 'main':
1807+
pytest.skip("test requires nats-server@main")
1808+
1809+
nc = await nats.connect(
1810+
'nats://127.0.0.1:4224',
1811+
tls=self.ssl_ctx,
1812+
tls_handshake_first=True
1813+
)
1814+
self.assertEqual(nc._server_info['max_payload'], nc.max_payload)
1815+
self.assertTrue(nc._server_info['tls_required'])
1816+
self.assertTrue(nc._server_info['tls_verify'])
1817+
self.assertTrue(nc.max_payload > 0)
1818+
self.assertTrue(nc.is_connected)
1819+
await nc.close()
1820+
self.assertTrue(nc.is_closed)
1821+
self.assertFalse(nc.is_connected)
1822+
1823+
@async_test
1824+
async def test_default_connect_using_tls_scheme(self):
1825+
if os.environ.get('NATS_SERVER_VERSION') != 'main':
1826+
pytest.skip("test requires nats-server@main")
1827+
1828+
nc = NATS()
1829+
1830+
# Will attempt to connect using TLS with default certs.
1831+
with self.assertRaises(ssl.SSLError):
1832+
await nc.connect(
1833+
servers=['tls://127.0.0.1:4224'],
1834+
allow_reconnect=False,
1835+
tls_handshake_first=True,
1836+
)
1837+
1838+
@async_test
1839+
async def test_default_connect_using_tls_scheme_in_url(self):
1840+
if os.environ.get('NATS_SERVER_VERSION') != 'main':
1841+
pytest.skip("test requires nats-server@main")
1842+
1843+
nc = NATS()
1844+
1845+
# Will attempt to connect using TLS with default certs.
1846+
with self.assertRaises(ssl.SSLError):
1847+
await nc.connect(
1848+
'tls://127.0.0.1:4224',
1849+
allow_reconnect=False,
1850+
tls_handshake_first=True
1851+
)
1852+
1853+
@async_test
1854+
async def test_connect_tls_with_custom_hostname(self):
1855+
if os.environ.get('NATS_SERVER_VERSION') != 'main':
1856+
pytest.skip("test requires nats-server@main")
1857+
1858+
nc = NATS()
1859+
1860+
# Will attempt to connect using TLS with an invalid hostname.
1861+
with self.assertRaises(ssl.SSLError):
1862+
await nc.connect(
1863+
servers=['nats://127.0.0.1:4224'],
1864+
tls=self.ssl_ctx,
1865+
tls_hostname="nats.example",
1866+
tls_handshake_first=True,
1867+
allow_reconnect=False,
1868+
)
1869+
1870+
@async_test
1871+
async def test_subscribe(self):
1872+
if os.environ.get('NATS_SERVER_VERSION') != 'main':
1873+
pytest.skip("test requires nats-server@main")
1874+
1875+
nc = NATS()
1876+
msgs = []
1877+
1878+
async def subscription_handler(msg):
1879+
msgs.append(msg)
1880+
1881+
payload = b'hello world'
1882+
await nc.connect(
1883+
servers=['nats://127.0.0.1:4224'],
1884+
tls=self.ssl_ctx,
1885+
tls_handshake_first=True
1886+
)
1887+
sub = await nc.subscribe("foo", cb=subscription_handler)
1888+
await nc.publish("foo", payload)
1889+
await nc.publish("bar", payload)
1890+
1891+
with self.assertRaises(nats.errors.BadSubjectError):
1892+
await nc.publish("", b'')
1893+
1894+
# Wait a bit for message to be received.
1895+
await asyncio.sleep(0.2)
1896+
1897+
self.assertEqual(1, len(msgs))
1898+
msg = msgs[0]
1899+
self.assertEqual('foo', msg.subject)
1900+
self.assertEqual('', msg.reply)
1901+
self.assertEqual(payload, msg.data)
1902+
self.assertEqual(1, sub._received)
1903+
await nc.close()
1904+
1905+
18001906
class ClusterDiscoveryTest(ClusteringTestCase):
18011907

18021908
@async_test

tests/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,33 @@ def tearDown(self):
253253
self.loop.close()
254254

255255

256+
class TLSServerHandshakeFirstTestCase(unittest.TestCase):
257+
258+
def setUp(self):
259+
super().setUp()
260+
self.loop = asyncio.new_event_loop()
261+
262+
self.natsd = NATSD(
263+
port=4224,
264+
config_file=get_config_file('conf/tls_handshake_first.conf')
265+
)
266+
start_natsd(self.natsd)
267+
268+
self.ssl_ctx = ssl.create_default_context(
269+
purpose=ssl.Purpose.SERVER_AUTH
270+
)
271+
# self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2
272+
self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem'))
273+
self.ssl_ctx.load_cert_chain(
274+
certfile=get_config_file('certs/client-cert.pem'),
275+
keyfile=get_config_file('certs/client-key.pem')
276+
)
277+
278+
def tearDown(self):
279+
self.natsd.stop()
280+
self.loop.close()
281+
282+
256283
class MultiTLSServerAuthTestCase(unittest.TestCase):
257284

258285
def setUp(self):

0 commit comments

Comments
 (0)