Skip to content

Commit 9adfa30

Browse files
committed
tx state handler
1 parent 8bf6bea commit 9adfa30

File tree

3 files changed

+157
-37
lines changed

3 files changed

+157
-37
lines changed

ydb/_grpc/grpcwrapper/ydb_query.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,29 @@ def from_proto(msg: ydb_query_pb2.BeginTransactionResponse) -> "BeginTransaction
116116
tx_meta=TransactionMeta.from_proto(msg.tx_meta),
117117
)
118118

119+
120+
@dataclass
121+
class CommitTransactionResponse(IFromProto):
122+
status: Optional[ServerStatus]
123+
124+
@staticmethod
125+
def from_proto(msg: ydb_query_pb2.CommitTransactionResponse) -> "CommitTransactionResponse":
126+
return CommitTransactionResponse(
127+
status=ServerStatus(msg.status, msg.issues),
128+
)
129+
130+
131+
@dataclass
132+
class RollbackTransactionResponse(IFromProto):
133+
status: Optional[ServerStatus]
134+
135+
@staticmethod
136+
def from_proto(msg: ydb_query_pb2.RollbackTransactionResponse) -> "RollbackTransactionResponse":
137+
return RollbackTransactionResponse(
138+
status=ServerStatus(msg.status, msg.issues),
139+
)
140+
141+
119142
@dataclass
120143
class QueryContent(IFromPublic, IToProto):
121144
text: str

ydb/query/base.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import abc
2+
import enum
3+
import functools
24

35
from typing import (
46
Optional,
@@ -13,10 +15,51 @@
1315
QuerySerializableReadWrite
1416
)
1517
from .. import convert
18+
from .. import issues
1619

1720
class QueryClientSettings: ...
1821

1922

23+
class QuerySessionStateEnum(enum.Enum):
24+
NOT_INITIALIZED = "NOT_INITIALIZED"
25+
CREATED = "CREATED"
26+
CLOSED = "CLOSED"
27+
28+
29+
class QuerySessionStateHelper(abc.ABC):
30+
_VALID_TRANSITIONS = {
31+
QuerySessionStateEnum.NOT_INITIALIZED: [QuerySessionStateEnum.CREATED],
32+
QuerySessionStateEnum.CREATED: [QuerySessionStateEnum.CLOSED],
33+
QuerySessionStateEnum.CLOSED: []
34+
}
35+
36+
@classmethod
37+
def valid_transition(cls, before: QuerySessionStateEnum, after: QuerySessionStateEnum) -> bool:
38+
return after in cls._VALID_TRANSITIONS[before]
39+
40+
41+
class QueryTxStateEnum(enum.Enum):
42+
NOT_INITIALIZED = "NOT_INITIALIZED"
43+
BEGINED = "BEGINED"
44+
COMMITTED = "COMMITTED"
45+
ROLLBACKED = "ROLLBACKED"
46+
DEAD = "DEAD"
47+
48+
49+
class QueryTxStateHelper(abc.ABC):
50+
_VALID_TRANSITIONS = {
51+
QueryTxStateEnum.NOT_INITIALIZED: [QueryTxStateEnum.BEGINED, QueryTxStateEnum.DEAD],
52+
QueryTxStateEnum.BEGINED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED, QueryTxStateEnum.DEAD],
53+
QueryTxStateEnum.COMMITTED: [],
54+
QueryTxStateEnum.ROLLBACKED: [],
55+
QueryTxStateEnum.DEAD: [],
56+
}
57+
58+
@classmethod
59+
def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool:
60+
return after in cls._VALID_TRANSITIONS[before]
61+
62+
2063
class QuerySessionState:
2164
_session_id: Optional[str]
2265
_node_id: Optional[int]
@@ -99,15 +142,15 @@ def tx_id(self):
99142
pass
100143

101144
@abc.abstractmethod
102-
def begin():
145+
def begin(settings: QueryClientSettings = None):
103146
pass
104147

105148
@abc.abstractmethod
106-
def commit():
149+
def commit(settings: QueryClientSettings = None):
107150
pass
108151

109152
@abc.abstractmethod
110-
def rollback():
153+
def rollback(settings: QueryClientSettings = None):
111154
pass
112155

113156
@abc.abstractmethod
@@ -142,9 +185,26 @@ def create_execute_query_request(query: str, session_id: str, commit_tx: bool):
142185

143186
def wrap_execute_query_response(rpc_state, response_pb):
144187

145-
# print("RESP:")
146-
# print(f"meta: {response_pb.tx_meta}")
147-
# print(response_pb)
188+
return convert.ResultSet.from_message(response_pb.result_set)
148189

190+
X_YDB_SERVER_HINTS = "x-ydb-server-hints"
191+
X_YDB_SESSION_CLOSE = "session-close"
149192

150-
return convert.ResultSet.from_message(response_pb.result_set)
193+
194+
def _check_session_is_closing(rpc_state, session_state):
195+
metadata = rpc_state.trailing_metadata()
196+
if X_YDB_SESSION_CLOSE in metadata.get(X_YDB_SERVER_HINTS, []):
197+
session_state.set_closing()
198+
199+
200+
def bad_session_handler(func):
201+
@functools.wraps(func)
202+
def decorator(rpc_state, response_pb, session_state, *args, **kwargs):
203+
try:
204+
_check_session_is_closing(rpc_state, session_state)
205+
return func(rpc_state, response_pb, session_state, *args, **kwargs)
206+
except issues.BadSession:
207+
session_state.reset()
208+
raise
209+
210+
return decorator

ydb/query/transaction.py

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import abc
22
import logging
3+
import enum
4+
import functools
35

46
from .. import (
57
_apis,
@@ -9,8 +11,6 @@
911
from .._grpc.grpcwrapper import ydb_query as _ydb_query
1012
from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public
1113

12-
from .._tx_ctx_impl import TxState, reset_tx_id_handler
13-
from .._session_impl import bad_session_handler
1414
from . import base
1515

1616
logger = logging.getLogger(__name__)
@@ -30,6 +30,38 @@
3030
# return tx_mode
3131

3232

33+
def reset_tx_id_handler(func):
34+
@functools.wraps(func)
35+
def decorator(rpc_state, response_pb, session_state, tx_state, *args, **kwargs):
36+
try:
37+
return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs)
38+
except issues.Error:
39+
tx_state.change_state(base.QueryTxStateEnum.DEAD)
40+
tx_state.tx_id = None
41+
raise
42+
43+
return decorator
44+
45+
46+
class QueryTxState:
47+
def __init__(self, tx_mode: base.BaseQueryTxMode):
48+
"""
49+
Holds transaction context manager info
50+
:param tx_mode: A mode of transaction
51+
"""
52+
self.tx_id = None
53+
self.tx_mode = tx_mode
54+
self._state = base.QueryTxStateEnum.NOT_INITIALIZED
55+
56+
def check_invalid_transition(self, target: base.QueryTxStateEnum):
57+
if not base.QueryTxStateHelper.is_valid_transition(self._state, target):
58+
raise RuntimeError(f"Transaction could not be moved from {self._state.value} to {target.value}")
59+
60+
def change_state(self, target: base.QueryTxStateEnum):
61+
self.check_invalid_transition(target)
62+
self._state = target
63+
64+
3365
def _construct_tx_settings(tx_state):
3466
tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode)
3567
return tx_settings
@@ -41,8 +73,6 @@ def _create_begin_transaction_request(session_state, tx_state):
4173
tx_settings=_construct_tx_settings(tx_state),
4274
).to_proto()
4375

44-
print(request)
45-
4676
return request
4777

4878

@@ -59,32 +89,35 @@ def _create_rollback_transaction_request(session_state, tx_state):
5989
return request
6090

6191

62-
@bad_session_handler
92+
@base.bad_session_handler
6393
def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx):
64-
# session_state.complete_query()
65-
# issues._process_response(response_pb.operation)
66-
print("wrap result")
6794
message = _ydb_query.BeginTransactionResponse.from_proto(response_pb)
68-
95+
issues._process_response(message.status)
96+
tx_state.change_state(base.QueryTxStateEnum.BEGINED)
6997
tx_state.tx_id = message.tx_meta.tx_id
7098
return tx
7199

72100

73-
@bad_session_handler
101+
@base.bad_session_handler
74102
@reset_tx_id_handler
75-
def wrap_result_on_rollback_or_commit_tx(rpc_state, response_pb, session_state, tx_state, tx):
103+
def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx):
104+
message = _ydb_query.CommitTransactionResponse(response_pb)
105+
issues._process_response(message.status)
106+
tx_state.tx_id = None
107+
tx_state.change_state(base.QueryTxStateEnum.COMMITTED)
108+
return tx
76109

77-
# issues._process_response(response_pb.operation)
78-
# transaction successfully committed or rolled back
110+
@base.bad_session_handler
111+
@reset_tx_id_handler
112+
def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, tx):
113+
message = _ydb_query.RollbackTransactionResponse(response_pb)
114+
issues._process_response(message.status)
79115
tx_state.tx_id = None
116+
tx_state.change_state(base.QueryTxStateEnum.ROLLBACKED)
80117
return tx
81118

82119

83120
class BaseTxContext(base.IQueryTxContext):
84-
85-
_COMMIT = "commit"
86-
_ROLLBACK = "rollback"
87-
88121
def __init__(self, driver, session_state, session, tx_mode=None):
89122
"""
90123
An object that provides a simple transaction context manager that allows statements execution
@@ -106,7 +139,7 @@ def __init__(self, driver, session_state, session, tx_mode=None):
106139
self._driver = driver
107140
if tx_mode is None:
108141
tx_mode = _ydb_query_public.QuerySerializableReadWrite()
109-
self._tx_state = TxState(tx_mode)
142+
self._tx_state = QueryTxState(tx_mode)
110143
self._session_state = session_state
111144
self.session = session
112145
self._finished = ""
@@ -163,17 +196,15 @@ def begin(self, settings=None):
163196
164197
:return: An open transaction
165198
"""
166-
if self._tx_state.tx_id is not None:
167-
return self
168-
169-
print('try to begin tx')
199+
self._tx_state.check_invalid_transition(base.QueryTxStateEnum.BEGINED)
170200

171201
return self._driver(
172202
_create_begin_transaction_request(self._session_state, self._tx_state),
173203
_apis.QueryService.Stub,
174204
_apis.QueryService.BeginTransaction,
175-
wrap_result=wrap_tx_begin_response,
176-
wrap_args=(self._session_state, self._tx_state, self),
205+
wrap_tx_begin_response,
206+
settings,
207+
(self._session_state, self._tx_state, self),
177208
)
178209

179210
def commit(self, settings=None):
@@ -186,22 +217,28 @@ def commit(self, settings=None):
186217
:return: A committed transaction or exception if commit is failed
187218
"""
188219

189-
self._set_finish(self._COMMIT)
190-
191-
if self._tx_state.tx_id is None and not self._tx_state.dead:
192-
return self
220+
self._tx_state.check_invalid_transition(base.QueryTxStateEnum.COMMITTED)
193221

194222
return self._driver(
195223
_create_commit_transaction_request(self._session_state, self._tx_state),
196224
_apis.QueryService.Stub,
197225
_apis.QueryService.CommitTransaction,
198-
wrap_result_on_rollback_or_commit_tx,
226+
wrap_tx_commit_response,
199227
settings,
200228
(self._session_state, self._tx_state, self),
201229
)
202230

203231
def rollback(self, settings=None):
204-
pass
232+
self._tx_state.check_invalid_transition(base.QueryTxStateEnum.ROLLBACKED)
233+
234+
return self._driver(
235+
_create_rollback_transaction_request(self._session_state, self._tx_state),
236+
_apis.QueryService.Stub,
237+
_apis.QueryService.RollbackTransaction,
238+
wrap_tx_rollback_response,
239+
settings,
240+
(self._session_state, self._tx_state, self),
241+
)
205242

206243
def execute(self, query, parameters=None, commit_tx=False, settings=None):
207244
pass

0 commit comments

Comments
 (0)