Skip to content

Commit 3bce385

Browse files
authored
Add SSL support (#66)
1 parent 51d410e commit 3bce385

7 files changed

+67
-13
lines changed

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ services:
1111
ARTEMIS_HOST: artemis
1212

1313
artemis:
14-
image: apache/activemq-artemis:2.34.0-alpine
14+
image: apache/activemq-artemis:2.37.0-alpine
1515
environment:
1616
ARTEMIS_USER: admin
1717
ARTEMIS_PASSWORD: ":=123"

stompman/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from contextlib import AsyncExitStack, asynccontextmanager
44
from dataclasses import dataclass, field
55
from functools import partial
6+
from ssl import SSLContext
67
from types import TracebackType
7-
from typing import ClassVar, Self
8+
from typing import ClassVar, Literal, Self
89

910
from stompman.config import ConnectionParameters, Heartbeat
1011
from stompman.connection import AbstractConnection, Connection
@@ -32,6 +33,7 @@ class Client:
3233
on_heartbeat: Callable[[], None] | None = None
3334

3435
heartbeat: Heartbeat = field(default=Heartbeat(1000, 1000))
36+
ssl: Literal[True] | SSLContext | None = None
3537
connect_retry_attempts: int = 3
3638
connect_retry_interval: int = 1
3739
connect_timeout: int = 2
@@ -71,6 +73,7 @@ def __post_init__(self) -> None:
7173
read_timeout=self.read_timeout,
7274
read_max_chunk_size=self.read_max_chunk_size,
7375
write_retry_attempts=self.write_retry_attempts,
76+
ssl=self.ssl,
7477
)
7578

7679
async def __aenter__(self) -> Self:

stompman/connection.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from collections.abc import AsyncGenerator, Generator, Iterator
44
from contextlib import contextmanager, suppress
55
from dataclasses import dataclass
6-
from typing import Protocol, Self, cast
6+
from ssl import SSLContext
7+
from typing import Literal, Protocol, Self, cast
78

89
from stompman.errors import ConnectionLostError
910
from stompman.frames import AnyClientFrame, AnyServerFrame
@@ -14,7 +15,14 @@
1415
class AbstractConnection(Protocol):
1516
@classmethod
1617
async def connect(
17-
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
18+
cls,
19+
*,
20+
host: str,
21+
port: int,
22+
timeout: int,
23+
read_max_chunk_size: int,
24+
read_timeout: int,
25+
ssl: Literal[True] | SSLContext | None,
1826
) -> Self | None: ...
1927
async def close(self) -> None: ...
2028
def write_heartbeat(self) -> None: ...
@@ -36,17 +44,31 @@ class Connection(AbstractConnection):
3644
writer: asyncio.StreamWriter
3745
read_max_chunk_size: int
3846
read_timeout: int
47+
ssl: Literal[True] | SSLContext | None
3948

4049
@classmethod
4150
async def connect(
42-
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
51+
cls,
52+
*,
53+
host: str,
54+
port: int,
55+
timeout: int,
56+
read_max_chunk_size: int,
57+
read_timeout: int,
58+
ssl: Literal[True] | SSLContext | None,
4359
) -> Self | None:
4460
try:
4561
reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout)
4662
except (TimeoutError, ConnectionError, socket.gaierror):
4763
return None
4864
else:
49-
return cls(reader=reader, writer=writer, read_max_chunk_size=read_max_chunk_size, read_timeout=read_timeout)
65+
return cls(
66+
reader=reader,
67+
writer=writer,
68+
read_max_chunk_size=read_max_chunk_size,
69+
read_timeout=read_timeout,
70+
ssl=ssl,
71+
)
5072

5173
async def close(self) -> None:
5274
self.writer.close()

stompman/connection_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
22
from collections.abc import AsyncGenerator
33
from dataclasses import dataclass, field
4+
from ssl import SSLContext
45
from types import TracebackType
5-
from typing import TYPE_CHECKING, Self
6+
from typing import TYPE_CHECKING, Literal, Self
67

78
from stompman.config import ConnectionParameters
89
from stompman.connection import AbstractConnection
@@ -34,6 +35,7 @@ class ConnectionManager:
3435
connect_retry_attempts: int
3536
connect_retry_interval: int
3637
connect_timeout: int
38+
ssl: Literal[True] | SSLContext | None
3739
read_timeout: int
3840
read_max_chunk_size: int
3941
write_retry_attempts: int
@@ -63,6 +65,7 @@ async def _create_connection_to_one_server(self, server: ConnectionParameters) -
6365
timeout=self.connect_timeout,
6466
read_max_chunk_size=self.read_max_chunk_size,
6567
read_timeout=self.read_timeout,
68+
ssl=self.ssl,
6669
):
6770
return ActiveConnectionState(
6871
connection=connection,

tests/conftest.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
from collections.abc import AsyncGenerator
33
from dataclasses import dataclass, field
4-
from typing import Any, Self, TypeVar
4+
from ssl import SSLContext
5+
from typing import Any, Literal, Self, TypeVar
56

67
import pytest
78
from polyfactory.factories.dataclass_factory import DataclassFactory
@@ -37,7 +38,14 @@ def noop_error_handler(exception: Exception, frame: stompman.MessageFrame) -> No
3738
class BaseMockConnection(AbstractConnection):
3839
@classmethod
3940
async def connect(
40-
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
41+
cls,
42+
*,
43+
host: str,
44+
port: int,
45+
timeout: int,
46+
read_max_chunk_size: int,
47+
read_timeout: int,
48+
ssl: Literal[True] | SSLContext | None,
4149
) -> Self | None:
4250
return cls()
4351

@@ -78,6 +86,7 @@ class EnrichedConnectionManager(ConnectionManager):
7886
read_timeout: int = 4
7987
read_max_chunk_size: int = 5
8088
write_retry_attempts: int = 3
89+
ssl: Literal[True] | SSLContext | None = None
8190

8291

8392
DataclassType = TypeVar("DataclassType")

tests/test_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
async def make_connection() -> Connection | None:
2525
return await Connection.connect(
26-
host="localhost", port=12345, timeout=2, read_max_chunk_size=1024 * 1024, read_timeout=2
26+
host="localhost", port=12345, timeout=2, read_max_chunk_size=1024 * 1024, read_timeout=2, ssl=None
2727
)
2828

2929

tests/test_connection_manager.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, AsyncIterable
3-
from typing import Self
3+
from ssl import SSLContext
4+
from typing import Literal, Self
45
from unittest import mock
56

67
import pytest
@@ -29,7 +30,14 @@ async def test_connect_attempts_ok(ok_on_attempt: int, monkeypatch: pytest.Monke
2930
class MockConnection(BaseMockConnection):
3031
@classmethod
3132
async def connect(
32-
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
33+
cls,
34+
*,
35+
host: str,
36+
port: int,
37+
timeout: int,
38+
read_max_chunk_size: int,
39+
read_timeout: int,
40+
ssl: Literal[True] | SSLContext | None,
3341
) -> Self | None:
3442
assert (host, port) == (manager.servers[0].host, manager.servers[0].port)
3543
nonlocal attempts
@@ -42,6 +50,7 @@ async def connect(
4250
timeout=timeout,
4351
read_max_chunk_size=read_max_chunk_size,
4452
read_timeout=read_timeout,
53+
ssl=ssl,
4554
)
4655
if attempts == ok_on_attempt
4756
else None
@@ -67,7 +76,14 @@ async def test_connect_to_any_server_ok() -> None:
6776
class MockConnection(BaseMockConnection):
6877
@classmethod
6978
async def connect(
70-
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
79+
cls,
80+
*,
81+
host: str,
82+
port: int,
83+
timeout: int,
84+
read_max_chunk_size: int,
85+
read_timeout: int,
86+
ssl: Literal[True] | SSLContext | None,
7187
) -> Self | None:
7288
return (
7389
await super().connect(
@@ -76,6 +92,7 @@ async def connect(
7692
timeout=timeout,
7793
read_max_chunk_size=read_max_chunk_size,
7894
read_timeout=read_timeout,
95+
ssl=ssl,
7996
)
8097
if port == successful_server.port
8198
else None

0 commit comments

Comments
 (0)