Skip to content

Commit 8e86e23

Browse files
authored
Change transaction API (#42)
1 parent 6197f32 commit 8e86e23

File tree

6 files changed

+82
-48
lines changed

6 files changed

+82
-48
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ await client.send(body=b"hi there!", destination="DLQ", headers={"persistent": "
4646
Or, to send messages in a transaction:
4747

4848
```python
49-
async with client.enter_transaction() as transaction:
49+
async with client.begin() as transaction:
5050
for _ in range(10):
51-
await client.send(body=b"hi there!", destination="DLQ", transaction=transaction, headers={"persistent": "true"})
51+
await transaction.send(body=b"hi there!", destination="DLQ", headers={"persistent": "true"})
5252
await asyncio.sleep(0.1)
5353
```
5454

@@ -108,7 +108,7 @@ async def handle_message(event: stompman.MessageEvent) -> None:
108108

109109
### Cleaning Up
110110

111-
stompman takes care of cleaning up resources automatically. When you leave the context of async context managers `stompman.Client()`, `client.subscribe()`, or `client.enter_transaction()`, the necessary frames will be sent to the server.
111+
stompman takes care of cleaning up resources automatically. When you leave the context of async context managers `stompman.Client()`, `client.subscribe()`, or `client.begin()`, the necessary frames will be sent to the server.
112112

113113
### Handling Connectivity Issues
114114

stompman/client.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
NackFrame,
2929
ReceiptFrame,
3030
SendFrame,
31-
SendHeaders,
3231
SubscribeFrame,
3332
UnsubscribeFrame,
3433
)
@@ -241,12 +240,12 @@ async def send_heartbeats_forever() -> None:
241240
break
242241

243242
@asynccontextmanager
244-
async def enter_transaction(self) -> AsyncGenerator[str, None]:
243+
async def begin(self) -> AsyncGenerator["Transaction", None]:
245244
transaction_id = str(uuid4())
246245
await self._connection.write_frame(BeginFrame(headers={"transaction": transaction_id}))
247246

248247
try:
249-
yield transaction_id
248+
yield Transaction(id=transaction_id, _connection=self._connection)
250249
except Exception:
251250
if self._connection.active:
252251
await self._connection.write_frame(AbortFrame(headers={"transaction": transaction_id}))
@@ -255,22 +254,18 @@ async def enter_transaction(self) -> AsyncGenerator[str, None]:
255254
if self._connection.active:
256255
await self._connection.write_frame(CommitFrame(headers={"transaction": transaction_id}))
257256

258-
async def send( # noqa: PLR0913
257+
async def send(
259258
self,
260259
body: bytes,
261260
destination: str,
262-
transaction: str | None = None,
263261
content_type: str | None = None,
264262
headers: dict[str, str] | None = None,
265263
) -> None:
266-
full_headers: SendHeaders = headers or {} # type: ignore[assignment]
267-
full_headers["destination"] = destination
268-
full_headers["content-length"] = str(len(body))
269-
if content_type is not None:
270-
full_headers["content-type"] = content_type
271-
if transaction is not None:
272-
full_headers["transaction"] = transaction
273-
await self._connection.write_frame(SendFrame(headers=full_headers, body=body))
264+
await self._connection.write_frame(
265+
SendFrame.build(
266+
body=body, destination=destination, transaction=None, content_type=content_type, headers=headers
267+
)
268+
)
274269

275270
@asynccontextmanager
276271
async def subscribe(self, destination: str) -> AsyncGenerator[None, None]:
@@ -300,6 +295,25 @@ async def listen(self) -> AsyncIterator["AnyListeningEvent"]:
300295
raise AssertionError(msg, frame)
301296

302297

298+
@dataclass(kw_only=True, slots=True)
299+
class Transaction:
300+
id: str
301+
_connection: AbstractConnection
302+
303+
async def send(
304+
self,
305+
body: bytes,
306+
destination: str,
307+
content_type: str | None = None,
308+
headers: dict[str, str] | None = None,
309+
) -> None:
310+
await self._connection.write_frame(
311+
SendFrame.build(
312+
body=body, destination=destination, transaction=self.id, content_type=content_type, headers=headers
313+
)
314+
)
315+
316+
303317
@dataclass(kw_only=True, slots=True)
304318
class MessageEvent:
305319
body: bytes = field(init=False)

stompman/frames.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Literal, NotRequired, TypedDict
2+
from typing import Literal, NotRequired, Self, TypedDict
33

44
ConnectHeaders = TypedDict(
55
"ConnectHeaders",
@@ -140,6 +140,25 @@ class SendFrame:
140140
headers: SendHeaders
141141
body: bytes = b""
142142

143+
@classmethod
144+
def build( # noqa: PLR0913
145+
cls,
146+
*,
147+
body: bytes,
148+
destination: str,
149+
transaction: str | None,
150+
content_type: str | None,
151+
headers: dict[str, str] | None,
152+
) -> Self:
153+
full_headers: SendHeaders = headers or {} # type: ignore[assignment]
154+
full_headers["destination"] = destination
155+
full_headers["content-length"] = str(len(body))
156+
if content_type is not None:
157+
full_headers["content-type"] = content_type
158+
if transaction is not None:
159+
full_headers["transaction"] = transaction
160+
return cls(headers=full_headers, body=body)
161+
143162

144163
@dataclass(frozen=True, kw_only=True, slots=True)
145164
class SubscribeFrame:

testing/producer.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55

66

77
async def main() -> None:
8-
async with (
9-
stompman.Client(servers=[CONNECTION_PARAMETERS]) as client,
10-
client.enter_transaction() as transaction,
11-
):
8+
async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client, client.begin() as transaction:
129
for _ in range(10):
13-
await client.send(body=b"hi there!", destination="DLQ", transaction=transaction)
10+
await transaction.send(body=b"hi there!", destination="DLQ")
1411
await asyncio.sleep(3)
15-
await client.send(body=b"hi there!", destination="DLQ", transaction=transaction)
12+
await transaction.send(body=b"hi there!", destination="DLQ")
1613

1714

1815
if __name__ == "__main__":

tests/integration.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,9 @@ def destination() -> str:
5050

5151
async def test_ok(destination: str) -> None:
5252
async def produce() -> None:
53-
async with producer.enter_transaction() as transaction:
53+
async with producer.begin() as transaction:
5454
for message in messages:
55-
await producer.send(
56-
body=message, destination=destination, transaction=transaction, headers={"hello": "world"}
57-
)
55+
await transaction.send(body=message, destination=destination, headers={"hello": "world"})
5856

5957
async def consume() -> None:
6058
received_messages = []
@@ -102,19 +100,19 @@ async def test_not_raises_connection_lost_error_in_subscription(client: stompman
102100

103101

104102
async def test_not_raises_connection_lost_error_in_transaction_without_send(client: stompman.Client) -> None:
105-
async with client.enter_transaction():
103+
async with client.begin():
106104
await client._connection.close()
107105

108106

109107
async def test_not_raises_connection_lost_error_in_transaction_with_send(
110108
client: stompman.Client, destination: str
111109
) -> None:
112-
async with client.enter_transaction() as transaction:
113-
await client.send(b"first", destination=destination, transaction=transaction)
110+
async with client.begin() as transaction:
111+
await transaction.send(b"first", destination=destination)
114112
await client._connection.close()
115113

116114
with pytest.raises(ConnectionLostError):
117-
await client.send(b"second", destination=destination, transaction=transaction)
115+
await transaction.send(b"second", destination=destination)
118116

119117

120118
async def test_raises_connection_lost_error_in_send(client: stompman.Client, destination: str) -> None:

tests/test_client.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ConnectedFrame,
2121
ConnectFrame,
2222
ConnectionConfirmationTimeoutError,
23+
ConnectionLostError,
2324
ConnectionParameters,
2425
DisconnectFrame,
2526
ErrorEvent,
@@ -36,7 +37,6 @@
3637
UnsubscribeFrame,
3738
UnsupportedProtocolVersionError,
3839
)
39-
from stompman.errors import ConnectionLostError
4040

4141
pytestmark = pytest.mark.anyio
4242

@@ -470,53 +470,59 @@ async def test_send_message_and_enter_transaction_ok(monkeypatch: pytest.MonkeyP
470470
body = b"hello"
471471
destination = "/queue/test"
472472
expires = "whatever"
473-
transaction = "myid"
473+
transaction_id = "myid"
474474
content_type = "my-content-type"
475-
monkeypatch.setattr(stompman.client, "uuid4", lambda: transaction)
475+
monkeypatch.setattr(stompman.client, "uuid4", lambda: transaction_id)
476476

477477
connection_class, collected_frames = create_spying_connection(get_read_frames_with_lifespan([]))
478478
async with (
479479
EnrichedClient(connection_class=connection_class) as client,
480-
client.enter_transaction() as transaction,
480+
client.begin() as transaction,
481481
):
482-
await client.send(
483-
body=body,
484-
destination=destination,
485-
transaction=transaction,
486-
content_type=content_type,
487-
headers={"expires": expires},
482+
await transaction.send(
483+
body=body, destination=destination, content_type=content_type, headers={"expires": expires}
488484
)
485+
await client.send(body=body, destination=destination, content_type=content_type, headers={"expires": expires})
489486

490487
assert_frames_between_lifespan_match(
491488
collected_frames,
492489
[
493-
BeginFrame(headers={"transaction": transaction}),
490+
BeginFrame(headers={"transaction": transaction_id}),
494491
SendFrame(
495492
headers={ # type: ignore[typeddict-unknown-key]
496493
"content-length": str(len(body)),
497494
"content-type": content_type,
498495
"destination": destination,
499-
"transaction": transaction,
496+
"transaction": transaction_id,
500497
"expires": expires,
501498
},
502499
body=b"hello",
503500
),
504-
CommitFrame(headers={"transaction": transaction}),
501+
SendFrame(
502+
headers={ # type: ignore[typeddict-unknown-key]
503+
"content-length": str(len(body)),
504+
"content-type": content_type,
505+
"destination": destination,
506+
"expires": expires,
507+
},
508+
body=b"hello",
509+
),
510+
CommitFrame(headers={"transaction": transaction_id}),
505511
],
506512
)
507513

508514

509515
async def test_send_message_and_enter_transaction_abort(monkeypatch: pytest.MonkeyPatch) -> None:
510-
transaction = "myid"
511-
monkeypatch.setattr(stompman.client, "uuid4", lambda: transaction)
516+
transaction_id = "myid"
517+
monkeypatch.setattr(stompman.client, "uuid4", lambda: transaction_id)
512518

513519
connection_class, collected_frames = create_spying_connection(get_read_frames_with_lifespan([]))
514520
async with EnrichedClient(connection_class=connection_class) as client:
515521
with suppress(AssertionError):
516-
async with client.enter_transaction() as transaction:
522+
async with client.begin():
517523
raise AssertionError
518524

519525
assert_frames_between_lifespan_match(
520526
collected_frames,
521-
[BeginFrame(headers={"transaction": transaction}), AbortFrame(headers={"transaction": transaction})],
527+
[BeginFrame(headers={"transaction": transaction_id}), AbortFrame(headers={"transaction": transaction_id})],
522528
)

0 commit comments

Comments
 (0)