1
1
import abc
2
2
import logging
3
+ import enum
4
+ import functools
3
5
4
6
from .. import (
5
7
_apis ,
9
11
from .._grpc .grpcwrapper import ydb_query as _ydb_query
10
12
from .._grpc .grpcwrapper import ydb_query_public_types as _ydb_query_public
11
13
12
- from .._tx_ctx_impl import TxState , reset_tx_id_handler
13
- from .._session_impl import bad_session_handler
14
14
from . import base
15
15
16
16
logger = logging .getLogger (__name__ )
30
30
# return tx_mode
31
31
32
32
33
+ def reset_tx_id_handler (func ):
34
+ @functools .wraps (func )
35
+ def decorator (rpc_state , response_pb , session_state , tx_state , * args , ** kwargs ):
36
+ try :
37
+ return func (rpc_state , response_pb , session_state , tx_state , * args , ** kwargs )
38
+ except issues .Error :
39
+ tx_state .change_state (base .QueryTxStateEnum .DEAD )
40
+ tx_state .tx_id = None
41
+ raise
42
+
43
+ return decorator
44
+
45
+
46
+ class QueryTxState :
47
+ def __init__ (self , tx_mode : base .BaseQueryTxMode ):
48
+ """
49
+ Holds transaction context manager info
50
+ :param tx_mode: A mode of transaction
51
+ """
52
+ self .tx_id = None
53
+ self .tx_mode = tx_mode
54
+ self ._state = base .QueryTxStateEnum .NOT_INITIALIZED
55
+
56
+ def check_invalid_transition (self , target : base .QueryTxStateEnum ):
57
+ if not base .QueryTxStateHelper .is_valid_transition (self ._state , target ):
58
+ raise RuntimeError (f"Transaction could not be moved from { self ._state .value } to { target .value } " )
59
+
60
+ def change_state (self , target : base .QueryTxStateEnum ):
61
+ self .check_invalid_transition (target )
62
+ self ._state = target
63
+
64
+
33
65
def _construct_tx_settings (tx_state ):
34
66
tx_settings = _ydb_query .TransactionSettings .from_public (tx_state .tx_mode )
35
67
return tx_settings
@@ -41,8 +73,6 @@ def _create_begin_transaction_request(session_state, tx_state):
41
73
tx_settings = _construct_tx_settings (tx_state ),
42
74
).to_proto ()
43
75
44
- print (request )
45
-
46
76
return request
47
77
48
78
@@ -59,32 +89,35 @@ def _create_rollback_transaction_request(session_state, tx_state):
59
89
return request
60
90
61
91
62
- @bad_session_handler
92
+ @base . bad_session_handler
63
93
def wrap_tx_begin_response (rpc_state , response_pb , session_state , tx_state , tx ):
64
- # session_state.complete_query()
65
- # issues._process_response(response_pb.operation)
66
- print ("wrap result" )
67
94
message = _ydb_query .BeginTransactionResponse .from_proto (response_pb )
68
-
95
+ issues ._process_response (message .status )
96
+ tx_state .change_state (base .QueryTxStateEnum .BEGINED )
69
97
tx_state .tx_id = message .tx_meta .tx_id
70
98
return tx
71
99
72
100
73
- @bad_session_handler
101
+ @base . bad_session_handler
74
102
@reset_tx_id_handler
75
- def wrap_result_on_rollback_or_commit_tx (rpc_state , response_pb , session_state , tx_state , tx ):
103
+ def wrap_tx_commit_response (rpc_state , response_pb , session_state , tx_state , tx ):
104
+ message = _ydb_query .CommitTransactionResponse (response_pb )
105
+ issues ._process_response (message .status )
106
+ tx_state .tx_id = None
107
+ tx_state .change_state (base .QueryTxStateEnum .COMMITTED )
108
+ return tx
76
109
77
- # issues._process_response(response_pb.operation)
78
- # transaction successfully committed or rolled back
110
+ @base .bad_session_handler
111
+ @reset_tx_id_handler
112
+ def wrap_tx_rollback_response (rpc_state , response_pb , session_state , tx_state , tx ):
113
+ message = _ydb_query .RollbackTransactionResponse (response_pb )
114
+ issues ._process_response (message .status )
79
115
tx_state .tx_id = None
116
+ tx_state .change_state (base .QueryTxStateEnum .ROLLBACKED )
80
117
return tx
81
118
82
119
83
120
class BaseTxContext (base .IQueryTxContext ):
84
-
85
- _COMMIT = "commit"
86
- _ROLLBACK = "rollback"
87
-
88
121
def __init__ (self , driver , session_state , session , tx_mode = None ):
89
122
"""
90
123
An object that provides a simple transaction context manager that allows statements execution
@@ -106,7 +139,7 @@ def __init__(self, driver, session_state, session, tx_mode=None):
106
139
self ._driver = driver
107
140
if tx_mode is None :
108
141
tx_mode = _ydb_query_public .QuerySerializableReadWrite ()
109
- self ._tx_state = TxState (tx_mode )
142
+ self ._tx_state = QueryTxState (tx_mode )
110
143
self ._session_state = session_state
111
144
self .session = session
112
145
self ._finished = ""
@@ -163,17 +196,15 @@ def begin(self, settings=None):
163
196
164
197
:return: An open transaction
165
198
"""
166
- if self ._tx_state .tx_id is not None :
167
- return self
168
-
169
- print ('try to begin tx' )
199
+ self ._tx_state .check_invalid_transition (base .QueryTxStateEnum .BEGINED )
170
200
171
201
return self ._driver (
172
202
_create_begin_transaction_request (self ._session_state , self ._tx_state ),
173
203
_apis .QueryService .Stub ,
174
204
_apis .QueryService .BeginTransaction ,
175
- wrap_result = wrap_tx_begin_response ,
176
- wrap_args = (self ._session_state , self ._tx_state , self ),
205
+ wrap_tx_begin_response ,
206
+ settings ,
207
+ (self ._session_state , self ._tx_state , self ),
177
208
)
178
209
179
210
def commit (self , settings = None ):
@@ -186,22 +217,28 @@ def commit(self, settings=None):
186
217
:return: A committed transaction or exception if commit is failed
187
218
"""
188
219
189
- self ._set_finish (self ._COMMIT )
190
-
191
- if self ._tx_state .tx_id is None and not self ._tx_state .dead :
192
- return self
220
+ self ._tx_state .check_invalid_transition (base .QueryTxStateEnum .COMMITTED )
193
221
194
222
return self ._driver (
195
223
_create_commit_transaction_request (self ._session_state , self ._tx_state ),
196
224
_apis .QueryService .Stub ,
197
225
_apis .QueryService .CommitTransaction ,
198
- wrap_result_on_rollback_or_commit_tx ,
226
+ wrap_tx_commit_response ,
199
227
settings ,
200
228
(self ._session_state , self ._tx_state , self ),
201
229
)
202
230
203
231
def rollback (self , settings = None ):
204
- pass
232
+ self ._tx_state .check_invalid_transition (base .QueryTxStateEnum .ROLLBACKED )
233
+
234
+ return self ._driver (
235
+ _create_rollback_transaction_request (self ._session_state , self ._tx_state ),
236
+ _apis .QueryService .Stub ,
237
+ _apis .QueryService .RollbackTransaction ,
238
+ wrap_tx_rollback_response ,
239
+ settings ,
240
+ (self ._session_state , self ._tx_state , self ),
241
+ )
205
242
206
243
def execute (self , query , parameters = None , commit_tx = False , settings = None ):
207
244
pass
0 commit comments