Skip to content

Commit f8e3edd

Browse files
committed
fix async transactions
1 parent 74ecb33 commit f8e3edd

File tree

4 files changed

+86
-21
lines changed

4 files changed

+86
-21
lines changed

tests/aio/query/test_query_transaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,6 @@ async def test_execute_as_context_manager(self, tx: QueryTxContextAsync):
8989
await tx.begin()
9090

9191
async with await tx.execute("select 1;") as results:
92-
res = [result_set for result_set in results]
92+
res = [result_set async for result_set in results]
9393

9494
assert len(res) == 1

ydb/aio/query/session.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,14 @@ class QuerySessionAsync(BaseQuerySession):
2626
_loop: asyncio.AbstractEventLoop
2727
_status_stream: _utilities.AsyncResponseIterator = None
2828

29-
def __init__(self, driver: base.SupportedDriverType, settings: Optional[base.QueryClientSettings] = None):
29+
def __init__(
30+
self,
31+
driver: base.SupportedDriverType,
32+
settings: Optional[base.QueryClientSettings] = None,
33+
loop: asyncio.AbstractEventLoop = None
34+
):
3035
super(QuerySessionAsync, self).__init__(driver, settings)
31-
self._loop = asyncio.get_running_loop()
36+
self._loop = loop if loop is not None else asyncio.get_running_loop()
3237

3338
async def _attach(self) -> None:
3439
self._stream = await self._attach_call()
@@ -94,6 +99,7 @@ def transaction(self, tx_mode=None) -> base.IQueryTxContext:
9499
self._state,
95100
self,
96101
tx_mode,
102+
self._loop,
97103
)
98104

99105
async def execute(

ydb/aio/query/transaction.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from typing import (
34
Optional,
@@ -15,6 +16,36 @@
1516

1617

1718
class QueryTxContextAsync(BaseQueryTxContext):
19+
_loop: asyncio.AbstractEventLoop
20+
21+
def __init__(self, driver, session_state, session, tx_mode, loop):
22+
"""
23+
An object that provides a simple transaction context manager that allows statements execution
24+
in a transaction. You don't have to open transaction explicitly, because context manager encapsulates
25+
transaction control logic, and opens new transaction if:
26+
27+
1) By explicit .begin() method;
28+
2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip
29+
30+
This context manager is not thread-safe, so you should not manipulate on it concurrently.
31+
32+
:param driver: A driver instance
33+
:param session_state: A state of session
34+
:param tx_mode: Transaction mode, which is a one from the following choises:
35+
1) QuerySerializableReadWrite() which is default mode;
36+
2) QueryOnlineReadOnly(allow_inconsistent_reads=False);
37+
3) QuerySnapshotReadOnly();
38+
4) QueryStaleReadOnly().
39+
"""
40+
41+
super(QueryTxContextAsync, self).__init__(
42+
driver,
43+
session_state,
44+
session,
45+
tx_mode,
46+
)
47+
self._loop = loop
48+
1849
async def __aenter__(self) -> "QueryTxContextAsync":
1950
"""
2051
Enters a context manager and returns a transaction
@@ -40,6 +71,12 @@ async def __aexit__(self, *args, **kwargs):
4071
except issues.Error:
4172
logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id)
4273

74+
async def _ensure_prev_stream_finished(self) -> None:
75+
if self._prev_stream is not None:
76+
async for _ in self._prev_stream:
77+
pass
78+
self._prev_stream = None
79+
4380
async def begin(self, settings: Optional[base.QueryClientSettings] = None) -> None:
4481
"""WARNING: This API is experimental and could be changed.
4582
@@ -61,9 +98,27 @@ async def commit(self, settings: Optional[base.QueryClientSettings] = None) -> N
6198
6299
:return: A committed transaction or exception if commit is failed
63100
"""
101+
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
102+
return
103+
104+
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
105+
self._tx_state._change_state(QueryTxStateEnum.COMMITTED)
106+
return
107+
108+
await self._ensure_prev_stream_finished()
109+
64110
await self._commit_call(settings)
65111

66112
async def rollback(self, settings: Optional[base.QueryClientSettings] = None) -> None:
113+
if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
114+
return
115+
116+
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
117+
self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED)
118+
return
119+
120+
await self._ensure_prev_stream_finished()
121+
67122
await self._rollback_call(settings)
68123

69124
async def execute(
@@ -93,6 +148,8 @@ async def execute(
93148
94149
:return: Iterator with result sets
95150
"""
151+
await self._ensure_prev_stream_finished()
152+
96153
stream_it = await self._execute_call(
97154
query=query,
98155
commit_tx=commit_tx,

ydb/query/transaction.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,6 @@ def _begin_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQuer
252252
)
253253

254254
def _commit_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext":
255-
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
256-
return
257-
self._ensure_prev_stream_finished()
258-
259-
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
260-
self._tx_state._change_state(QueryTxStateEnum.COMMITTED)
261-
return
262-
263255
self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED)
264256

265257
return self._driver(
@@ -272,15 +264,6 @@ def _commit_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQue
272264
)
273265

274266
def _rollback_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext":
275-
if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
276-
return
277-
278-
self._ensure_prev_stream_finished()
279-
280-
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
281-
self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED)
282-
return
283-
284267
self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED)
285268

286269
return self._driver(
@@ -301,7 +284,6 @@ def _execute_call(
301284
parameters: dict = None,
302285
concurrent_result_sets: bool = False,
303286
) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]:
304-
self._ensure_prev_stream_finished()
305287
self._tx_state._check_tx_ready_to_use()
306288

307289
request = base.create_execute_query_request(
@@ -362,9 +344,27 @@ def commit(self, settings: Optional[base.QueryClientSettings] = None) -> None:
362344
363345
:return: A committed transaction or exception if commit is failed
364346
"""
347+
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
348+
return
349+
350+
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
351+
self._tx_state._change_state(QueryTxStateEnum.COMMITTED)
352+
return
353+
354+
self._ensure_prev_stream_finished()
355+
365356
self._commit_call(settings)
366357

367358
def rollback(self, settings: Optional[base.QueryClientSettings] = None) -> None:
359+
if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
360+
return
361+
362+
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
363+
self._tx_state._change_state(QueryTxStateEnum.ROLLBACKED)
364+
return
365+
366+
self._ensure_prev_stream_finished()
367+
368368
self._rollback_call(settings)
369369

370370
def execute(
@@ -394,6 +394,8 @@ def execute(
394394
395395
:return: Iterator with result sets
396396
"""
397+
self._ensure_prev_stream_finished()
398+
397399
stream_it = self._execute_call(
398400
query=query,
399401
commit_tx=commit_tx,

0 commit comments

Comments
 (0)