Skip to content

Commit 9e5be0f

Browse files
committed
async transactions
1 parent 496d4c1 commit 9e5be0f

File tree

10 files changed

+306
-67
lines changed

10 files changed

+306
-67
lines changed

tests/aio/query/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,17 @@ async def session(driver):
1010
try:
1111
await session.delete()
1212
except BaseException:
13-
pass
13+
pass
14+
15+
16+
@pytest.fixture
17+
async def tx(session):
18+
await session.create()
19+
transaction = session.transaction()
20+
21+
yield transaction
22+
23+
try:
24+
await transaction.rollback()
25+
except BaseException:
26+
pass

tests/aio/query/test_query_session.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,23 @@ async def test_create_after_delete_not_possible(self, session: QuerySessionAsync
5858
with pytest.raises(RuntimeError):
5959
await session.create()
6060

61-
# def test_transaction_before_create_raises(self, session: QuerySessionAsync):
62-
# with pytest.raises(RuntimeError):
63-
# session.transaction()
61+
def test_transaction_before_create_raises(self, session: QuerySessionAsync):
62+
with pytest.raises(RuntimeError):
63+
session.transaction()
6464

65-
# def test_transaction_after_delete_raises(self, session: QuerySessionAsync):
66-
# session.create()
65+
@pytest.mark.asyncio
66+
async def test_transaction_after_delete_raises(self, session: QuerySessionAsync):
67+
await session.create()
6768

68-
# session.delete()
69+
await session.delete()
6970

70-
# with pytest.raises(RuntimeError):
71-
# session.transaction()
71+
with pytest.raises(RuntimeError):
72+
session.transaction()
7273

73-
# def test_transaction_after_create_not_raises(self, session: QuerySessionAsync):
74-
# session.create()
75-
# session.transaction()
74+
@pytest.mark.asyncio
75+
async def test_transaction_after_create_not_raises(self, session: QuerySessionAsync):
76+
await session.create()
77+
session.transaction()
7678

7779
@pytest.mark.asyncio
7880
async def test_execute_before_create_raises(self, session: QuerySessionAsync):
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import pytest
2+
3+
from ydb.aio.query.transaction import QueryTxContextAsync
4+
from ydb.query.transaction import QueryTxStateEnum
5+
6+
7+
class TestAsyncQueryTransaction:
8+
@pytest.mark.asyncio
9+
async def test_tx_begin(self, tx: QueryTxContextAsync):
10+
assert tx.tx_id is None
11+
12+
await tx.begin()
13+
assert tx.tx_id is not None
14+
15+
@pytest.mark.asyncio
16+
async def test_tx_allow_double_commit(self, tx: QueryTxContextAsync):
17+
await tx.begin()
18+
await tx.commit()
19+
await tx.commit()
20+
21+
@pytest.mark.asyncio
22+
async def test_tx_allow_double_rollback(self, tx: QueryTxContextAsync):
23+
await tx.begin()
24+
await tx.rollback()
25+
await tx.rollback()
26+
27+
@pytest.mark.asyncio
28+
async def test_tx_commit_before_begin(self, tx: QueryTxContextAsync):
29+
await tx.commit()
30+
assert tx._tx_state._state == QueryTxStateEnum.COMMITTED
31+
32+
@pytest.mark.asyncio
33+
async def test_tx_rollback_before_begin(self, tx: QueryTxContextAsync):
34+
await tx.rollback()
35+
assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED
36+
37+
@pytest.mark.asyncio
38+
async def test_tx_first_execute_begins_tx(self, tx: QueryTxContextAsync):
39+
await tx.execute("select 1;")
40+
await tx.commit()
41+
42+
@pytest.mark.asyncio
43+
async def test_interactive_tx_commit(self, tx: QueryTxContextAsync):
44+
await tx.execute("select 1;", commit_tx=True)
45+
with pytest.raises(RuntimeError):
46+
await tx.execute("select 1;")
47+
48+
@pytest.mark.asyncio
49+
async def test_tx_execute_raises_after_commit(self, tx: QueryTxContextAsync):
50+
await tx.begin()
51+
await tx.commit()
52+
with pytest.raises(RuntimeError):
53+
await tx.execute("select 1;")
54+
55+
@pytest.mark.asyncio
56+
async def test_tx_execute_raises_after_rollback(self, tx: QueryTxContextAsync):
57+
await tx.begin()
58+
await tx.rollback()
59+
with pytest.raises(RuntimeError):
60+
await tx.execute("select 1;")
61+
62+
@pytest.mark.asyncio
63+
async def test_context_manager_rollbacks_tx(self, tx: QueryTxContextAsync):
64+
async with tx:
65+
await tx.begin()
66+
67+
assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED
68+
69+
@pytest.mark.asyncio
70+
async def test_context_manager_normal_flow(self, tx: QueryTxContextAsync):
71+
async with tx:
72+
await tx.begin()
73+
await tx.execute("select 1;")
74+
await tx.commit()
75+
76+
assert tx._tx_state._state == QueryTxStateEnum.COMMITTED
77+
78+
@pytest.mark.asyncio
79+
async def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContextAsync):
80+
class CustomException(Exception):
81+
pass
82+
83+
with pytest.raises(CustomException):
84+
async with tx:
85+
raise CustomException()
86+
87+
@pytest.mark.asyncio
88+
async def test_execute_as_context_manager(self, tx: QueryTxContextAsync):
89+
await tx.begin()
90+
91+
async with await tx.execute("select 1;") as results:
92+
res = [result_set for result_set in results]
93+
94+
assert len(res) == 1

tests/query/test_query_transaction.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,78 @@
11
import pytest
22

3-
from ydb.query.transaction import BaseQueryTxContext
3+
from ydb.query.transaction import QueryTxContextSync
44
from ydb.query.transaction import QueryTxStateEnum
55

66

77
class TestQueryTransaction:
8-
def test_tx_begin(self, tx: BaseQueryTxContext):
8+
def test_tx_begin(self, tx: QueryTxContextSync):
99
assert tx.tx_id is None
1010

1111
tx.begin()
1212
assert tx.tx_id is not None
1313

14-
def test_tx_allow_double_commit(self, tx: BaseQueryTxContext):
14+
def test_tx_allow_double_commit(self, tx: QueryTxContextSync):
1515
tx.begin()
1616
tx.commit()
1717
tx.commit()
1818

19-
def test_tx_allow_double_rollback(self, tx: BaseQueryTxContext):
19+
def test_tx_allow_double_rollback(self, tx: QueryTxContextSync):
2020
tx.begin()
2121
tx.rollback()
2222
tx.rollback()
2323

24-
def test_tx_commit_before_begin(self, tx: BaseQueryTxContext):
24+
def test_tx_commit_before_begin(self, tx: QueryTxContextSync):
2525
tx.commit()
2626
assert tx._tx_state._state == QueryTxStateEnum.COMMITTED
2727

28-
def test_tx_rollback_before_begin(self, tx: BaseQueryTxContext):
28+
def test_tx_rollback_before_begin(self, tx: QueryTxContextSync):
2929
tx.rollback()
3030
assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED
3131

32-
def test_tx_first_execute_begins_tx(self, tx: BaseQueryTxContext):
32+
def test_tx_first_execute_begins_tx(self, tx: QueryTxContextSync):
3333
tx.execute("select 1;")
3434
tx.commit()
3535

36-
def test_interactive_tx_commit(self, tx: BaseQueryTxContext):
36+
def test_interactive_tx_commit(self, tx: QueryTxContextSync):
3737
tx.execute("select 1;", commit_tx=True)
3838
with pytest.raises(RuntimeError):
3939
tx.execute("select 1;")
4040

41-
def test_tx_execute_raises_after_commit(self, tx: BaseQueryTxContext):
41+
def test_tx_execute_raises_after_commit(self, tx: QueryTxContextSync):
4242
tx.begin()
4343
tx.commit()
4444
with pytest.raises(RuntimeError):
4545
tx.execute("select 1;")
4646

47-
def test_tx_execute_raises_after_rollback(self, tx: BaseQueryTxContext):
47+
def test_tx_execute_raises_after_rollback(self, tx: QueryTxContextSync):
4848
tx.begin()
4949
tx.rollback()
5050
with pytest.raises(RuntimeError):
5151
tx.execute("select 1;")
5252

53-
def test_context_manager_rollbacks_tx(self, tx: BaseQueryTxContext):
53+
def test_context_manager_rollbacks_tx(self, tx: QueryTxContextSync):
5454
with tx:
5555
tx.begin()
5656

5757
assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED
5858

59-
def test_context_manager_normal_flow(self, tx: BaseQueryTxContext):
59+
def test_context_manager_normal_flow(self, tx: QueryTxContextSync):
6060
with tx:
6161
tx.begin()
6262
tx.execute("select 1;")
6363
tx.commit()
6464

6565
assert tx._tx_state._state == QueryTxStateEnum.COMMITTED
6666

67-
def test_context_manager_does_not_hide_exceptions(self, tx: BaseQueryTxContext):
67+
def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContextSync):
6868
class CustomException(Exception):
6969
pass
7070

7171
with pytest.raises(CustomException):
7272
with tx:
7373
raise CustomException()
7474

75-
def test_execute_as_context_manager(self, tx: BaseQueryTxContext):
75+
def test_execute_as_context_manager(self, tx: QueryTxContextSync):
7676
tx.begin()
7777

7878
with tx.execute("select 1;") as results:

ydb/aio/_utilities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ def cancel(self):
77
self.it.cancel()
88
return self
99

10+
def __iter__(self):
11+
return self
12+
1013
def __aiter__(self):
1114
return self
1215

ydb/aio/query/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .. import _utilities
2+
3+
class AsyncResponseContextIterator(_utilities.AsyncResponseIterator):
4+
async def __aenter__(self) -> "AsyncResponseContextIterator":
5+
return self
6+
7+
async def __aexit__(self, exc_type, exc_val, exc_tb):
8+
async for _ in self:
9+
pass

ydb/aio/query/session.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,19 @@
44
Optional,
55
)
66

7+
from .base import AsyncResponseContextIterator
8+
from .transaction import QueryTxContextAsync
79
from .. import _utilities
810
from ... import issues
911
from ..._grpc.grpcwrapper import common_utils
12+
from ..._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public
13+
1014
from ...query import base
1115
from ...query.session import (
1216
BaseQuerySession,
1317
QuerySessionStateEnum,
1418
)
1519

16-
17-
class AsyncResponseContextIterator(_utilities.AsyncResponseIterator):
18-
async def __aenter__(self) -> "AsyncResponseContextIterator":
19-
return self
20-
21-
async def __aexit__(self, exc_type, exc_val, exc_tb):
22-
async for _ in self:
23-
pass
24-
25-
2620
class QuerySessionAsync(BaseQuerySession):
2721
"""Session object for Query Service. It is not recommended to control
2822
session's lifecycle manually - use a QuerySessionPool is always a better choise.
@@ -91,14 +85,21 @@ async def create(self) -> "QuerySessionAsync":
9185

9286
return self
9387

94-
async def transaction(self, tx_mode) -> base.IQueryTxContext:
95-
return super().transaction(tx_mode)
88+
def transaction(self, tx_mode = None) -> base.IQueryTxContext:
89+
self._state._check_session_ready_to_use()
90+
tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite()
91+
92+
return QueryTxContextAsync(
93+
self._driver,
94+
self._state,
95+
self,
96+
tx_mode,
97+
)
9698

9799
async def execute(
98100
self,
99101
query: str,
100102
parameters: dict = None,
101-
commit_tx: bool = False,
102103
syntax: base.QuerySyntax = None,
103104
exec_mode: base.QueryExecMode = None,
104105
concurrent_result_sets: bool = False,

0 commit comments

Comments
 (0)