Skip to content

Commit d9a0424

Browse files
committed
refactor transaction
1 parent 75ea82d commit d9a0424

File tree

5 files changed

+141
-53
lines changed

5 files changed

+141
-53
lines changed

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
version: "3.3"
22
services:
33
ydb:
4-
image: ydbplatform/local-ydb:latest
4+
image: ydbplatform/local-ydb:trunk
55
restart: always
66
ports:
77
- 2136:2136

tests/query/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def session(driver_sync):
1414
pass
1515

1616
@pytest.fixture
17-
def transaction(session):
17+
def tx(session):
1818
session.create()
1919
transaction = session.transaction()
2020

tests/query/test_query_transaction.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,44 @@
11
import pytest
22

3-
import ydb.query.session
3+
class TestQueryTransaction:
4+
def test_tx_begin(self, tx):
5+
assert tx.tx_id == None
46

5-
class TestQuerySession:
6-
def test_transaction_begin(self, driver_sync):
7-
session = ydb.query.session.QuerySessionSync(driver_sync)
7+
tx.begin()
8+
assert tx.tx_id != None
89

9-
session.create()
10+
def test_tx_allow_double_commit(self, tx):
11+
tx.begin()
12+
tx.commit()
13+
tx.commit()
1014

11-
tx = session.transaction()
15+
def test_tx_allow_double_rollback(self, tx):
16+
tx.begin()
17+
tx.rollback()
18+
tx.rollback()
1219

13-
assert tx.tx_id == None
20+
def test_tx_commit_raises_before_begin(self, tx):
21+
with pytest.raises(RuntimeError):
22+
tx.commit()
1423

24+
def test_tx_rollback_raises_before_begin(self, tx):
25+
with pytest.raises(RuntimeError):
26+
tx.rollback()
27+
28+
# def test_tx_execute_raises_before_begin(self, tx):
29+
# with pytest.raises(RuntimeError):
30+
# tx.execute("select 1;")
31+
32+
def text_tx_execute_raises_after_commit(self, tx):
1533
tx.begin()
34+
tx.commit()
35+
with pytest.raises(RuntimeError):
36+
tx.execute("select 1;")
37+
38+
def text_tx_execute_raises_after_rollback(self, tx):
39+
tx.begin()
40+
tx.rollback()
41+
with pytest.raises(RuntimeError):
42+
tx.execute("select 1;")
43+
1644

17-
assert tx.tx_id != None

ydb/query/base.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,32 @@ def session(self) -> IQuerySession:
124124
pass
125125

126126

127-
def create_execute_query_request(query: str, session_id: str, commit_tx: bool):
128-
req = ydb_query.ExecuteQueryRequest(
129-
session_id=session_id,
130-
query_content=ydb_query.QueryContent.from_public(
131-
query=query,
132-
),
133-
tx_control=ydb_query.TransactionControl(
134-
begin_tx=ydb_query.TransactionSettings(
135-
tx_mode=QuerySerializableReadWrite(),
127+
def create_execute_query_request(query: str, session_id: str, tx_id: str = None, commit_tx: bool = False, tx_mode: BaseQueryTxMode = None):
128+
if tx_id:
129+
req = ydb_query.ExecuteQueryRequest(
130+
session_id=session_id,
131+
query_content=ydb_query.QueryContent.from_public(
132+
query=query,
136133
),
137-
commit_tx=commit_tx
138-
),
139-
)
134+
tx_control=ydb_query.TransactionControl(
135+
tx_id=tx_id,
136+
commit_tx=commit_tx
137+
),
138+
)
139+
else:
140+
tx_mode = tx_mode if tx_mode is not None else QuerySerializableReadWrite()
141+
req = ydb_query.ExecuteQueryRequest(
142+
session_id=session_id,
143+
query_content=ydb_query.QueryContent.from_public(
144+
query=query,
145+
),
146+
tx_control=ydb_query.TransactionControl(
147+
begin_tx=ydb_query.TransactionSettings(
148+
tx_mode=tx_mode,
149+
),
150+
commit_tx=commit_tx
151+
),
152+
)
140153

141154
return req.to_proto()
142155

@@ -148,17 +161,17 @@ def wrap_execute_query_response(rpc_state, response_pb):
148161
X_YDB_SESSION_CLOSE = "session-close"
149162

150163

151-
def _check_session_is_closing(rpc_state, session_state):
152-
metadata = rpc_state.trailing_metadata()
153-
if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []):
154-
session_state.set_closing() # TODO: clarify & implement
164+
# def _check_session_is_closing(rpc_state, session_state):
165+
# metadata = rpc_state.trailing_metadata()
166+
# if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []):
167+
# session_state.set_closing() # TODO: clarify & implement
155168

156169

157170
def bad_session_handler(func):
158171
@functools.wraps(func)
159172
def decorator(rpc_state, response_pb, session_state, *args, **kwargs):
160173
try:
161-
_check_session_is_closing(rpc_state, session_state)
174+
# _check_session_is_closing(rpc_state, session_state)
162175
return func(rpc_state, response_pb, session_state, *args, **kwargs)
163176
except issues.BadSession:
164177
session_state.reset()

ydb/query/transaction.py

Lines changed: 74 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,37 @@
1515

1616
logger = logging.getLogger(__name__)
1717

18-
# def patch_table_service_tx_mode_to_query_service(tx_mode: AbstractTransactionModeBuilder):
19-
# if tx_mode.name == 'snapshot_read_only':
20-
# tx_mode = _ydb_query_public.QuerySnapshotReadOnly()
21-
# elif tx_mode.name == 'serializable_read_write':
22-
# tx_mode = _ydb_query_public.QuerySerializableReadWrite()
23-
# elif tx_mode.name =='online_read_only':
24-
# tx_mode = _ydb_query_public.QueryOnlineReadOnly()
25-
# elif tx_mode.name == 'stale_read_only':
26-
# tx_mode = _ydb_query_public.QueryStaleReadOnly()
27-
# else:
28-
# raise issues.YDBInvalidArgumentError(f'Unknown transaction mode: {tx_mode.name}')
2918

30-
# return tx_mode
19+
class QueryTxStateEnum(enum.Enum):
20+
NOT_INITIALIZED = "NOT_INITIALIZED"
21+
BEGINED = "BEGINED"
22+
COMMITTED = "COMMITTED"
23+
ROLLBACKED = "ROLLBACKED"
24+
DEAD = "DEAD"
25+
26+
27+
class QueryTxStateHelper(abc.ABC):
28+
_VALID_TRANSITIONS = {
29+
QueryTxStateEnum.NOT_INITIALIZED: [QueryTxStateEnum.BEGINED, QueryTxStateEnum.DEAD],
30+
QueryTxStateEnum.BEGINED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED, QueryTxStateEnum.DEAD],
31+
QueryTxStateEnum.COMMITTED: [],
32+
QueryTxStateEnum.ROLLBACKED: [],
33+
QueryTxStateEnum.DEAD: [],
34+
}
35+
36+
_TERMINAL_STATES = [
37+
QueryTxStateEnum.COMMITTED,
38+
QueryTxStateEnum.ROLLBACKED,
39+
QueryTxStateEnum.DEAD,
40+
]
41+
42+
@classmethod
43+
def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool:
44+
return after in cls._VALID_TRANSITIONS[before]
45+
46+
@classmethod
47+
def terminal(cls, state: QueryTxStateEnum) -> bool:
48+
return state in cls._TERMINAL_STATES
3149

3250

3351
def reset_tx_id_handler(func):
@@ -36,7 +54,7 @@ def decorator(rpc_state, response_pb, session_state, tx_state, *args, **kwargs):
3654
try:
3755
return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs)
3856
except issues.Error:
39-
tx_state.change_state(base.QueryTxStateEnum.DEAD)
57+
tx_state._change_state(QueryTxStateEnum.DEAD)
4058
tx_state.tx_id = None
4159
raise
4260

@@ -51,16 +69,23 @@ def __init__(self, tx_mode: base.BaseQueryTxMode):
5169
"""
5270
self.tx_id = None
5371
self.tx_mode = tx_mode
54-
self._state = base.QueryTxStateEnum.NOT_INITIALIZED
72+
self._state = QueryTxStateEnum.NOT_INITIALIZED
5573

56-
def check_invalid_transition(self, target: base.QueryTxStateEnum):
57-
if not base.QueryTxStateHelper.is_valid_transition(self._state, target):
74+
def _check_invalid_transition(self, target: QueryTxStateEnum):
75+
if not QueryTxStateHelper.valid_transition(self._state, target):
5876
raise RuntimeError(f"Transaction could not be moved from {self._state.value} to {target.value}")
5977

60-
def change_state(self, target: base.QueryTxStateEnum):
61-
self.check_invalid_transition(target)
78+
def _change_state(self, target: QueryTxStateEnum):
79+
self._check_invalid_transition(target)
6280
self._state = target
6381

82+
def _check_tx_not_terminal(self):
83+
if QueryTxStateHelper.terminal(self._state):
84+
raise RuntimeError(f"Transaction is in terminal state: {self._state.value}")
85+
86+
def _already_in(self, target: QueryTxStateEnum):
87+
return self._state == target
88+
6489

6590
def _construct_tx_settings(tx_state):
6691
tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode)
@@ -93,7 +118,7 @@ def _create_rollback_transaction_request(session_state, tx_state):
93118
def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx):
94119
message = _ydb_query.BeginTransactionResponse.from_proto(response_pb)
95120
issues._process_response(message.status)
96-
tx_state.change_state(base.QueryTxStateEnum.BEGINED)
121+
tx_state._change_state(QueryTxStateEnum.BEGINED)
97122
tx_state.tx_id = message.tx_meta.tx_id
98123
return tx
99124

@@ -104,7 +129,7 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx)
104129
message = _ydb_query.CommitTransactionResponse(response_pb)
105130
issues._process_response(message.status)
106131
tx_state.tx_id = None
107-
tx_state.change_state(base.QueryTxStateEnum.COMMITTED)
132+
tx_state._change_state(QueryTxStateEnum.COMMITTED)
108133
return tx
109134

110135
@base.bad_session_handler
@@ -113,7 +138,7 @@ def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, t
113138
message = _ydb_query.RollbackTransactionResponse(response_pb)
114139
issues._process_response(message.status)
115140
tx_state.tx_id = None
116-
tx_state.change_state(base.QueryTxStateEnum.ROLLBACKED)
141+
tx_state._change_state(QueryTxStateEnum.ROLLBACKED)
117142
return tx
118143

119144

@@ -196,7 +221,7 @@ def begin(self, settings=None):
196221
197222
:return: An open transaction
198223
"""
199-
self._tx_state.check_invalid_transition(base.QueryTxStateEnum.BEGINED)
224+
self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED)
200225

201226
return self._driver(
202227
_create_begin_transaction_request(self._session_state, self._tx_state),
@@ -216,8 +241,9 @@ def commit(self, settings=None):
216241
217242
:return: A committed transaction or exception if commit is failed
218243
"""
219-
220-
self._tx_state.check_invalid_transition(base.QueryTxStateEnum.COMMITTED)
244+
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
245+
return
246+
self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED)
221247

222248
return self._driver(
223249
_create_commit_transaction_request(self._session_state, self._tx_state),
@@ -229,7 +255,10 @@ def commit(self, settings=None):
229255
)
230256

231257
def rollback(self, settings=None):
232-
self._tx_state.check_invalid_transition(base.QueryTxStateEnum.ROLLBACKED)
258+
if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
259+
return
260+
261+
self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED)
233262

234263
return self._driver(
235264
_create_rollback_transaction_request(self._session_state, self._tx_state),
@@ -240,5 +269,24 @@ def rollback(self, settings=None):
240269
(self._session_state, self._tx_state, self),
241270
)
242271

272+
def _execute_call(self, query: str, commit_tx: bool):
273+
request = base.create_execute_query_request(
274+
query=query,
275+
session_id=self._session_state.session_id,
276+
commit_tx=commit_tx
277+
)
278+
return self._driver(
279+
request,
280+
_apis.QueryService.Stub,
281+
_apis.QueryService.ExecuteQuery,
282+
)
283+
243284
def execute(self, query, parameters=None, commit_tx=False, settings=None):
244-
pass
285+
self._tx_state._check_tx_not_terminal()
286+
287+
stream_it = self._execute_call(query, commit_tx)
288+
289+
return _utilities.SyncResponseIterator(
290+
stream_it,
291+
lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp),
292+
)

0 commit comments

Comments
 (0)