Skip to content

Commit 2d8a2ff

Browse files
committed
style fixes
1 parent d0c9388 commit 2d8a2ff

File tree

3 files changed

+80
-36
lines changed

3 files changed

+80
-36
lines changed

ydb/query/base.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from .._grpc.grpcwrapper.ydb_query_public_types import (
1515
BaseQueryTxMode,
1616
)
17+
from ..connection import _RpcState as RpcState
1718
from .. import convert
1819
from .. import issues
1920
from .. import _utilities
21+
from .. import _apis
2022

2123

2224
class QuerySyntax(enum.IntEnum):
@@ -42,7 +44,7 @@ class StatsMode(enum.IntEnum):
4244

4345

4446
class SyncResponseContextIterator(_utilities.SyncResponseIterator):
45-
def __enter__(self):
47+
def __enter__(self) -> "SyncResponseContextIterator":
4648
return self
4749

4850
def __exit__(self, exc_type, exc_val, exc_tb):
@@ -315,7 +317,7 @@ def create_execute_query_request(
315317
exec_mode: Optional[QueryExecMode],
316318
parameters: Optional[dict],
317319
concurrent_result_sets: Optional[bool],
318-
):
320+
) -> ydb_query.ExecuteQueryRequest:
319321
syntax = QuerySyntax.YQL_V1 if not syntax else syntax
320322
exec_mode = QueryExecMode.EXECUTE if not exec_mode else exec_mode
321323
stats_mode = StatsMode.NONE # TODO: choise is not supported yet
@@ -338,7 +340,7 @@ def create_execute_query_request(
338340
tx_id=None,
339341
)
340342

341-
req = ydb_query.ExecuteQueryRequest(
343+
return ydb_query.ExecuteQueryRequest(
342344
session_id=session_id,
343345
query_content=ydb_query.QueryContent.from_public(
344346
query=query,
@@ -351,21 +353,24 @@ def create_execute_query_request(
351353
stats_mode=stats_mode,
352354
)
353355

354-
return req.to_proto()
355356

356-
357-
def wrap_execute_query_response(rpc_state, response_pb, tx=None, commit_tx=False):
357+
def wrap_execute_query_response(
358+
rpc_state: RpcState,
359+
response_pb: _apis.ydb_query.ExecuteQueryResponsePart,
360+
tx: Optional[IQueryTxContext] = None,
361+
commit_tx: Optional[bool] = False,
362+
) -> convert.ResultSet:
358363
issues._process_response(response_pb)
359364
if tx and response_pb.tx_meta and not tx.tx_id:
360-
tx._handle_tx_meta(response_pb.tx_meta)
365+
tx._move_to_beginned(response_pb.tx_meta.id)
361366
if tx and commit_tx:
362367
tx._move_to_commited()
363368
return convert.ResultSet.from_message(response_pb.result_set)
364369

365370

366371
def bad_session_handler(func):
367372
@functools.wraps(func)
368-
def decorator(rpc_state, response_pb, session_state, *args, **kwargs):
373+
def decorator(rpc_state, response_pb, session_state: IQuerySessionState, *args, **kwargs):
369374
try:
370375
return func(rpc_state, response_pb, session_state, *args, **kwargs)
371376
except issues.BadSession:

ydb/query/session.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import logging
44
import threading
55
from typing import (
6+
Iterable,
67
Optional,
78
)
89

910
from . import base
1011

1112
from .. import _apis, issues, _utilities
13+
from ..connection import _RpcState as RpcState
1214
from .._grpc.grpcwrapper import common_utils
1315
from .._grpc.grpcwrapper import ydb_query as _ydb_query
1416
from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public
@@ -99,14 +101,24 @@ def _already_in(self, target) -> bool:
99101
return self._state == target
100102

101103

102-
def wrapper_create_session(rpc_state, response_pb, session_state: QuerySessionState, session):
104+
def wrapper_create_session(
105+
rpc_state: RpcState,
106+
response_pb: _apis.ydb_query.CreateSessionResponse,
107+
session_state: QuerySessionState,
108+
session: "BaseQuerySession"
109+
) -> "BaseQuerySession":
103110
message = _ydb_query.CreateSessionResponse.from_proto(response_pb)
104111
issues._process_response(message.status)
105112
session_state.set_session_id(message.session_id).set_node_id(message.node_id)
106113
return session
107114

108115

109-
def wrapper_delete_session(rpc_state, response_pb, session_state: QuerySessionState, session):
116+
def wrapper_delete_session(
117+
rpc_state: RpcState,
118+
response_pb: _apis.ydb_query.DeleteSessionResponse,
119+
session_state: QuerySessionState,
120+
session: "BaseQuerySession"
121+
) -> "BaseQuerySession":
110122
message = _ydb_query.DeleteSessionResponse.from_proto(response_pb)
111123
issues._process_response(message.status)
112124
session_state.reset()
@@ -124,7 +136,7 @@ def __init__(self, driver: base.SupportedDriverType, settings: Optional[base.Que
124136
self._settings = settings if settings is not None else base.QueryClientSettings()
125137
self._state = QuerySessionState(settings)
126138

127-
def _create_call(self):
139+
def _create_call(self) -> "BaseQuerySession":
128140
return self._driver(
129141
_apis.ydb_query.CreateSessionRequest(),
130142
_apis.QueryService.Stub,
@@ -133,7 +145,7 @@ def _create_call(self):
133145
wrap_args=(self._state, self),
134146
)
135147

136-
def _delete_call(self):
148+
def _delete_call(self) -> "BaseQuerySession":
137149
return self._driver(
138150
_apis.ydb_query.DeleteSessionRequest(session_id=self._state.session_id),
139151
_apis.QueryService.Stub,
@@ -142,7 +154,7 @@ def _delete_call(self):
142154
wrap_args=(self._state, self),
143155
)
144156

145-
def _attach_call(self):
157+
def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]:
146158
return self._driver(
147159
_apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id),
148160
_apis.QueryService.Stub,
@@ -157,7 +169,7 @@ def _execute_call(
157169
exec_mode: base.QueryExecMode = None,
158170
parameters: dict = None,
159171
concurrent_result_sets: bool = False,
160-
):
172+
) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]:
161173
request = base.create_execute_query_request(
162174
query=query,
163175
session_id=self._state.session_id,
@@ -171,7 +183,7 @@ def _execute_call(
171183
)
172184

173185
return self._driver(
174-
request,
186+
request.to_proto(),
175187
_apis.QueryService.Stub,
176188
_apis.QueryService.ExecuteQuery,
177189
)
@@ -184,7 +196,7 @@ class QuerySessionSync(BaseQuerySession):
184196

185197
_stream = None
186198

187-
def _attach(self):
199+
def _attach(self) -> None:
188200
self._stream = self._attach_call()
189201
status_stream = _utilities.SyncResponseIterator(
190202
self._stream,
@@ -205,7 +217,7 @@ def _attach(self):
205217
daemon=True,
206218
).start()
207219

208-
def _check_session_status_loop(self, status_stream):
220+
def _check_session_status_loop(self, status_stream: _utilities.SyncResponseIterator) -> None:
209221
try:
210222
for status in status_stream:
211223
if status.status != issues.StatusCode.SUCCESS:

ydb/query/transaction.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import enum
44
import functools
55
from typing import (
6+
Iterable,
67
Optional,
78
)
89

@@ -11,6 +12,7 @@
1112
issues,
1213
)
1314
from .._grpc.grpcwrapper import ydb_query as _ydb_query
15+
from ..connection import _RpcState as RpcState
1416

1517
from . import base
1618

@@ -50,7 +52,7 @@ def terminal(cls, state: QueryTxStateEnum) -> bool:
5052

5153
def reset_tx_id_handler(func):
5254
@functools.wraps(func)
53-
def decorator(rpc_state, response_pb, session_state, tx_state, *args, **kwargs):
55+
def decorator(rpc_state, response_pb, session_state: base.IQuerySessionState, tx_state: QueryTxState, *args, **kwargs):
5456
try:
5557
return func(rpc_state, response_pb, session_state, tx_state, *args, **kwargs)
5658
except issues.Error:
@@ -87,35 +89,47 @@ def _already_in(self, target: QueryTxStateEnum) -> bool:
8789
return self._state == target
8890

8991

90-
def _construct_tx_settings(tx_state):
92+
def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings:
9193
tx_settings = _ydb_query.TransactionSettings.from_public(tx_state.tx_mode)
9294
return tx_settings
9395

9496

95-
def _create_begin_transaction_request(session_state, tx_state):
97+
def _create_begin_transaction_request(
98+
session_state: base.IQuerySessionState, tx_state: QueryTxState
99+
) -> _apis.ydb_query.BeginTransactionRequest:
96100
request = _ydb_query.BeginTransactionRequest(
97101
session_id=session_state.session_id,
98102
tx_settings=_construct_tx_settings(tx_state),
99103
).to_proto()
100104
return request
101105

102106

103-
def _create_commit_transaction_request(session_state, tx_state):
107+
def _create_commit_transaction_request(
108+
session_state: base.IQuerySessionState, tx_state: QueryTxState
109+
) -> _apis.ydb_query.CommitTransactionRequest:
104110
request = _apis.ydb_query.CommitTransactionRequest()
105111
request.tx_id = tx_state.tx_id
106112
request.session_id = session_state.session_id
107113
return request
108114

109115

110-
def _create_rollback_transaction_request(session_state, tx_state):
116+
def _create_rollback_transaction_request(
117+
session_state: base.IQuerySessionState, tx_state: QueryTxState
118+
) -> _apis.ydb_query.RollbackTransactionRequest:
111119
request = _apis.ydb_query.RollbackTransactionRequest()
112120
request.tx_id = tx_state.tx_id
113121
request.session_id = session_state.session_id
114122
return request
115123

116124

117125
@base.bad_session_handler
118-
def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx):
126+
def wrap_tx_begin_response(
127+
rpc_state: RpcState,
128+
response_pb: _apis.ydb_query.BeginTransactionResponse,
129+
session_state: base.IQuerySessionState,
130+
tx_state: QueryTxState,
131+
tx: "BaseQueryTxContext",
132+
) -> "BaseQueryTxContext":
119133
message = _ydb_query.BeginTransactionResponse.from_proto(response_pb)
120134
issues._process_response(message.status)
121135
tx_state._change_state(QueryTxStateEnum.BEGINED)
@@ -125,7 +139,13 @@ def wrap_tx_begin_response(rpc_state, response_pb, session_state, tx_state, tx):
125139

126140
@base.bad_session_handler
127141
@reset_tx_id_handler
128-
def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx):
142+
def wrap_tx_commit_response(
143+
rpc_state: RpcState,
144+
response_pb: _apis.ydb_query.CommitTransactionResponse,
145+
session_state: base.IQuerySessionState,
146+
tx_state: QueryTxState,
147+
tx: "BaseQueryTxContext",
148+
) -> "BaseQueryTxContext":
129149
message = _ydb_query.CommitTransactionResponse.from_proto(response_pb)
130150
issues._process_response(message.status)
131151
tx_state._change_state(QueryTxStateEnum.COMMITTED)
@@ -134,7 +154,13 @@ def wrap_tx_commit_response(rpc_state, response_pb, session_state, tx_state, tx)
134154

135155
@base.bad_session_handler
136156
@reset_tx_id_handler
137-
def wrap_tx_rollback_response(rpc_state, response_pb, session_state, tx_state, tx):
157+
def wrap_tx_rollback_response(
158+
rpc_state: RpcState,
159+
response_pb: _apis.ydb_query.RollbackTransactionResponse,
160+
session_state: base.IQuerySessionState,
161+
tx_state: QueryTxState,
162+
tx: "BaseQueryTxContext",
163+
) -> "BaseQueryTxContext":
138164
message = _ydb_query.RollbackTransactionResponse.from_proto(response_pb)
139165
issues._process_response(message.status)
140166
tx_state._change_state(QueryTxStateEnum.ROLLBACKED)
@@ -211,7 +237,7 @@ def tx_id(self) -> Optional[str]:
211237
"""
212238
return self._tx_state.tx_id
213239

214-
def _begin_call(self, settings: Optional[base.QueryClientSettings]):
240+
def _begin_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext":
215241
return self._driver(
216242
_create_begin_transaction_request(self._session_state, self._tx_state),
217243
_apis.QueryService.Stub,
@@ -221,7 +247,7 @@ def _begin_call(self, settings: Optional[base.QueryClientSettings]):
221247
(self._session_state, self._tx_state, self),
222248
)
223249

224-
def _commit_call(self, settings: Optional[base.QueryClientSettings]):
250+
def _commit_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext":
225251
return self._driver(
226252
_create_commit_transaction_request(self._session_state, self._tx_state),
227253
_apis.QueryService.Stub,
@@ -231,7 +257,7 @@ def _commit_call(self, settings: Optional[base.QueryClientSettings]):
231257
(self._session_state, self._tx_state, self),
232258
)
233259

234-
def _rollback_call(self, settings: Optional[base.QueryClientSettings]):
260+
def _rollback_call(self, settings: Optional[base.QueryClientSettings]) -> "BaseQueryTxContext":
235261
return self._driver(
236262
_create_rollback_transaction_request(self._session_state, self._tx_state),
237263
_apis.QueryService.Stub,
@@ -249,7 +275,7 @@ def _execute_call(
249275
exec_mode: base.QueryExecMode = None,
250276
parameters: dict = None,
251277
concurrent_result_sets: bool = False,
252-
):
278+
) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]:
253279
request = base.create_execute_query_request(
254280
query=query,
255281
session_id=self._session_state.session_id,
@@ -263,23 +289,24 @@ def _execute_call(
263289
)
264290

265291
return self._driver(
266-
request,
292+
request.to_proto(),
267293
_apis.QueryService.Stub,
268294
_apis.QueryService.ExecuteQuery,
269295
)
270296

271-
def _ensure_prev_stream_finished(self):
297+
def _ensure_prev_stream_finished(self) -> None:
272298
if self._prev_stream is not None:
273299
for _ in self._prev_stream:
274300
pass
275301
self._prev_stream = None
276302

277-
def _handle_tx_meta(self, tx_meta=None):
278-
if not self.tx_id and tx_meta:
279-
self._tx_state._change_state(QueryTxStateEnum.BEGINED)
280-
self._tx_state.tx_id = tx_meta.id
303+
def _move_to_beginned(self, tx_id: str) -> None:
304+
if self._tx_state._already_in(QueryTxStateEnum.BEGINED):
305+
return
306+
self._tx_state._change_state(QueryTxStateEnum.BEGINED)
307+
self._tx_state.tx_id = tx_id
281308

282-
def _move_to_commited(self):
309+
def _move_to_commited(self) -> None:
283310
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
284311
return
285312
self._tx_state._change_state(QueryTxStateEnum.COMMITTED)

0 commit comments

Comments
 (0)