Skip to content

Commit ec7274f

Browse files
committed
interactive tx support
1 parent c3d1d2b commit ec7274f

File tree

3 files changed

+32
-17
lines changed

3 files changed

+32
-17
lines changed

tests/query/test_query_transaction.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,14 @@ def test_tx_rollback_raises_before_begin(self, tx):
2626
with pytest.raises(RuntimeError):
2727
tx.rollback()
2828

29-
# def test_tx_execute_raises_before_begin(self, tx):
30-
# with pytest.raises(RuntimeError):
31-
# tx.execute("select 1;")
29+
def test_tx_first_execute_begins_tx(self, tx):
30+
tx.execute("select 1;")
31+
tx.commit()
32+
33+
def test_interactive_tx_commit(self, tx):
34+
tx.execute("select 1;", commit_tx=True)
35+
with pytest.raises(RuntimeError):
36+
tx.execute("select 1;")
3237

3338
def text_tx_execute_raises_after_commit(self, tx):
3439
tx.begin()

ydb/query/base.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,26 +192,23 @@ def create_execute_query_request(
192192
return req.to_proto()
193193

194194

195-
def wrap_execute_query_response(rpc_state, response_pb):
195+
def wrap_execute_query_response(rpc_state, response_pb, tx, commit_tx=False):
196196
issues._process_response(response_pb)
197+
if response_pb.tx_meta and not tx.tx_id:
198+
tx._handle_tx_meta(response_pb.tx_meta)
199+
if commit_tx:
200+
tx._move_to_commited()
197201
return convert.ResultSet.from_message(response_pb.result_set)
198202

199203

200204
X_YDB_SERVER_HINTS = "x-ydb-server-hints"
201205
X_YDB_SESSION_CLOSE = "session-close"
202206

203207

204-
# def _check_session_is_closing(rpc_state, session_state):
205-
# metadata = rpc_state.trailing_metadata()
206-
# if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []):
207-
# session_state.set_closing() # TODO: clarify & implement
208-
209-
210208
def bad_session_handler(func):
211209
@functools.wraps(func)
212210
def decorator(rpc_state, response_pb, session_state, *args, **kwargs):
213211
try:
214-
# _check_session_is_closing(rpc_state, session_state)
215212
return func(rpc_state, response_pb, session_state, *args, **kwargs)
216213
except issues.BadSession:
217214
session_state.reset()

ydb/query/transaction.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,8 @@ def commit(self, settings=None):
245245
"""
246246
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
247247
return
248-
self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED)
249-
250248
self._ensure_prev_stream_finished()
249+
self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED)
251250

252251
return self._driver(
253252
_create_commit_transaction_request(self._session_state, self._tx_state),
@@ -262,9 +261,8 @@ def rollback(self, settings=None):
262261
if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
263262
return
264263

265-
self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED)
266-
267264
self._ensure_prev_stream_finished()
265+
self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED)
268266

269267
return self._driver(
270268
_create_rollback_transaction_request(self._session_state, self._tx_state),
@@ -309,6 +307,16 @@ def _ensure_prev_stream_finished(self):
309307
pass
310308
self._prev_stream = None
311309

310+
def _handle_tx_meta(self, tx_meta=None):
311+
if not self.tx_id:
312+
self._tx_state._change_state(QueryTxStateEnum.BEGINED)
313+
self._tx_state.tx_id = tx_meta.id
314+
315+
def _move_to_commited(self):
316+
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
317+
return
318+
self._tx_state._change_state(QueryTxStateEnum.COMMITTED)
319+
312320
def execute(
313321
self,
314322
query: str,
@@ -319,8 +327,8 @@ def execute(
319327
parameters: dict = None,
320328
concurrent_result_sets: bool = False,
321329
):
322-
self._tx_state._check_tx_not_terminal()
323330
self._ensure_prev_stream_finished()
331+
self._tx_state._check_tx_not_terminal()
324332

325333
stream_it = self._execute_call(
326334
query=query,
@@ -333,6 +341,11 @@ def execute(
333341
)
334342
self._prev_stream = _utilities.SyncResponseIterator(
335343
stream_it,
336-
lambda resp: base.wrap_execute_query_response(rpc_state=None, response_pb=resp),
344+
lambda resp: base.wrap_execute_query_response(
345+
rpc_state=None,
346+
response_pb=resp,
347+
tx=self,
348+
commit_tx=commit_tx,
349+
),
337350
)
338351
return self._prev_stream

0 commit comments

Comments
 (0)