3
3
import enum
4
4
import functools
5
5
from typing import (
6
+ Iterable ,
6
7
Optional ,
7
8
)
8
9
11
12
issues ,
12
13
)
13
14
from .._grpc .grpcwrapper import ydb_query as _ydb_query
15
+ from ..connection import _RpcState as RpcState
14
16
15
17
from . import base
16
18
@@ -50,7 +52,7 @@ def terminal(cls, state: QueryTxStateEnum) -> bool:
50
52
51
53
def reset_tx_id_handler (func ):
52
54
@functools .wraps (func )
53
- def decorator (rpc_state , response_pb , session_state , tx_state , * args , ** kwargs ):
55
+ def decorator (rpc_state , response_pb , session_state : base . IQuerySessionState , tx_state : QueryTxState , * args , ** kwargs ):
54
56
try :
55
57
return func (rpc_state , response_pb , session_state , tx_state , * args , ** kwargs )
56
58
except issues .Error :
@@ -87,35 +89,47 @@ def _already_in(self, target: QueryTxStateEnum) -> bool:
87
89
return self ._state == target
88
90
89
91
90
- def _construct_tx_settings (tx_state ) :
92
+ def _construct_tx_settings (tx_state : QueryTxState ) -> _ydb_query . TransactionSettings :
91
93
tx_settings = _ydb_query .TransactionSettings .from_public (tx_state .tx_mode )
92
94
return tx_settings
93
95
94
96
95
- def _create_begin_transaction_request (session_state , tx_state ):
97
+ def _create_begin_transaction_request (
98
+ session_state : base .IQuerySessionState , tx_state : QueryTxState
99
+ ) -> _apis .ydb_query .BeginTransactionRequest :
96
100
request = _ydb_query .BeginTransactionRequest (
97
101
session_id = session_state .session_id ,
98
102
tx_settings = _construct_tx_settings (tx_state ),
99
103
).to_proto ()
100
104
return request
101
105
102
106
103
- def _create_commit_transaction_request (session_state , tx_state ):
107
+ def _create_commit_transaction_request (
108
+ session_state : base .IQuerySessionState , tx_state : QueryTxState
109
+ ) -> _apis .ydb_query .CommitTransactionRequest :
104
110
request = _apis .ydb_query .CommitTransactionRequest ()
105
111
request .tx_id = tx_state .tx_id
106
112
request .session_id = session_state .session_id
107
113
return request
108
114
109
115
110
- def _create_rollback_transaction_request (session_state , tx_state ):
116
+ def _create_rollback_transaction_request (
117
+ session_state : base .IQuerySessionState , tx_state : QueryTxState
118
+ ) -> _apis .ydb_query .RollbackTransactionRequest :
111
119
request = _apis .ydb_query .RollbackTransactionRequest ()
112
120
request .tx_id = tx_state .tx_id
113
121
request .session_id = session_state .session_id
114
122
return request
115
123
116
124
117
125
@base .bad_session_handler
118
- def wrap_tx_begin_response (rpc_state , response_pb , session_state , tx_state , tx ):
126
+ def wrap_tx_begin_response (
127
+ rpc_state : RpcState ,
128
+ response_pb : _apis .ydb_query .BeginTransactionResponse ,
129
+ session_state : base .IQuerySessionState ,
130
+ tx_state : QueryTxState ,
131
+ tx : "BaseQueryTxContext" ,
132
+ ) -> "BaseQueryTxContext" :
119
133
message = _ydb_query .BeginTransactionResponse .from_proto (response_pb )
120
134
issues ._process_response (message .status )
121
135
tx_state ._change_state (QueryTxStateEnum .BEGINED )
@@ -125,7 +139,13 @@ def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx):
125
139
126
140
@base .bad_session_handler
127
141
@reset_tx_id_handler
128
- def wrap_tx_commit_response (rpc_state , response_pb , session_state , tx_state , tx ):
142
+ def wrap_tx_commit_response (
143
+ rpc_state : RpcState ,
144
+ response_pb : _apis .ydb_query .CommitTransactionResponse ,
145
+ session_state : base .IQuerySessionState ,
146
+ tx_state : QueryTxState ,
147
+ tx : "BaseQueryTxContext" ,
148
+ ) -> "BaseQueryTxContext" :
129
149
message = _ydb_query .CommitTransactionResponse .from_proto (response_pb )
130
150
issues ._process_response (message .status )
131
151
tx_state ._change_state (QueryTxStateEnum .COMMITTED )
@@ -134,7 +154,13 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx)
134
154
135
155
@base .bad_session_handler
136
156
@reset_tx_id_handler
137
- def wrap_tx_rollback_response (rpc_state , response_pb , session_state , tx_state , tx ):
157
+ def wrap_tx_rollback_response (
158
+ rpc_state : RpcState ,
159
+ response_pb : _apis .ydb_query .RollbackTransactionResponse ,
160
+ session_state : base .IQuerySessionState ,
161
+ tx_state : QueryTxState ,
162
+ tx : "BaseQueryTxContext" ,
163
+ ) -> "BaseQueryTxContext" :
138
164
message = _ydb_query .RollbackTransactionResponse .from_proto (response_pb )
139
165
issues ._process_response (message .status )
140
166
tx_state ._change_state (QueryTxStateEnum .ROLLBACKED )
@@ -211,7 +237,7 @@ def tx_id(self) -> Optional[str]:
211
237
"""
212
238
return self ._tx_state .tx_id
213
239
214
- def _begin_call (self , settings : Optional [base .QueryClientSettings ]):
240
+ def _begin_call (self , settings : Optional [base .QueryClientSettings ]) -> "BaseQueryTxContext" :
215
241
return self ._driver (
216
242
_create_begin_transaction_request (self ._session_state , self ._tx_state ),
217
243
_apis .QueryService .Stub ,
@@ -221,7 +247,7 @@ def _begin_call(self, settings: Optional[base.QueryClientSettings]):
221
247
(self ._session_state , self ._tx_state , self ),
222
248
)
223
249
224
- def _commit_call (self , settings : Optional [base .QueryClientSettings ]):
250
+ def _commit_call (self , settings : Optional [base .QueryClientSettings ]) -> "BaseQueryTxContext" :
225
251
return self ._driver (
226
252
_create_commit_transaction_request (self ._session_state , self ._tx_state ),
227
253
_apis .QueryService .Stub ,
@@ -231,7 +257,7 @@ def _commit_call(self, settings: Optional[base.QueryClientSettings]):
231
257
(self ._session_state , self ._tx_state , self ),
232
258
)
233
259
234
- def _rollback_call (self , settings : Optional [base .QueryClientSettings ]):
260
+ def _rollback_call (self , settings : Optional [base .QueryClientSettings ]) -> "BaseQueryTxContext" :
235
261
return self ._driver (
236
262
_create_rollback_transaction_request (self ._session_state , self ._tx_state ),
237
263
_apis .QueryService .Stub ,
@@ -249,7 +275,7 @@ def _execute_call(
249
275
exec_mode : base .QueryExecMode = None ,
250
276
parameters : dict = None ,
251
277
concurrent_result_sets : bool = False ,
252
- ):
278
+ ) -> Iterable [ _apis . ydb_query . ExecuteQueryResponsePart ] :
253
279
request = base .create_execute_query_request (
254
280
query = query ,
255
281
session_id = self ._session_state .session_id ,
@@ -263,23 +289,24 @@ def _execute_call(
263
289
)
264
290
265
291
return self ._driver (
266
- request ,
292
+ request . to_proto () ,
267
293
_apis .QueryService .Stub ,
268
294
_apis .QueryService .ExecuteQuery ,
269
295
)
270
296
271
- def _ensure_prev_stream_finished (self ):
297
+ def _ensure_prev_stream_finished (self ) -> None :
272
298
if self ._prev_stream is not None :
273
299
for _ in self ._prev_stream :
274
300
pass
275
301
self ._prev_stream = None
276
302
277
- def _handle_tx_meta (self , tx_meta = None ):
278
- if not self .tx_id and tx_meta :
279
- self ._tx_state ._change_state (QueryTxStateEnum .BEGINED )
280
- self ._tx_state .tx_id = tx_meta .id
303
+ def _move_to_beginned (self , tx_id : str ) -> None :
304
+ if self ._tx_state ._already_in (QueryTxStateEnum .BEGINED ):
305
+ return
306
+ self ._tx_state ._change_state (QueryTxStateEnum .BEGINED )
307
+ self ._tx_state .tx_id = tx_id
281
308
282
- def _move_to_commited (self ):
309
+ def _move_to_commited (self ) -> None :
283
310
if self ._tx_state ._already_in (QueryTxStateEnum .COMMITTED ):
284
311
return
285
312
self ._tx_state ._change_state (QueryTxStateEnum .COMMITTED )
0 commit comments