5
5
import gzip
6
6
import typing
7
7
from asyncio import Task
8
- from collections import OrderedDict
8
+ from collections import defaultdict , OrderedDict
9
9
from typing import Optional , Set , Dict , Union , Callable
10
10
11
11
import ydb
19
19
from .._grpc .grpcwrapper .common_utils import (
20
20
IGrpcWrapperAsyncIO ,
21
21
SupportedDriverType ,
22
+ to_thread ,
22
23
GrpcWrapperAsyncIO ,
23
24
)
24
25
from .._grpc .grpcwrapper .ydb_topic import (
25
26
StreamReadMessage ,
26
27
UpdateTokenRequest ,
27
28
UpdateTokenResponse ,
29
+ UpdateOffsetsInTransactionRequest ,
28
30
Codec ,
29
31
)
30
32
from .._errors import check_retriable_error
31
33
import logging
32
34
35
+ from ..query .base import TxEvent
36
+
37
+ if typing .TYPE_CHECKING :
38
+ from ..query .transaction import BaseQueryTxContext
39
+
33
40
logger = logging .getLogger (__name__ )
34
41
35
42
@@ -77,7 +84,7 @@ def __init__(
77
84
):
78
85
self ._loop = asyncio .get_running_loop ()
79
86
self ._closed = False
80
- self ._reconnector = ReaderReconnector (driver , settings )
87
+ self ._reconnector = ReaderReconnector (driver , settings , self . _loop )
81
88
self ._parent = _parent
82
89
83
90
async def __aenter__ (self ):
@@ -88,8 +95,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
88
95
89
96
def __del__(self):
    """Finalizer: best-effort shutdown for a reader that was never closed.

    Explicit close() is the supported path; this only schedules an async
    close task and warns, and never lets an exception escape the finalizer.
    """
    if self._closed:
        return
    try:
        logger.warning("Topic reader was not closed properly. Consider using method close().")
        close_task = self._loop.create_task(self.close(flush=False))
        topic_common.wrap_set_name_for_asyncio_task(close_task, task_name="close reader")
    except BaseException:
        # __del__ may run during interpreter teardown; swallow everything.
        logger.warning("Something went wrong during reader close in __del__")
93
104
94
105
async def wait_message (self ):
95
106
"""
@@ -112,6 +123,23 @@ async def receive_batch(
112
123
max_messages = max_messages ,
113
124
)
114
125
126
async def receive_batch_with_tx(
    self,
    tx: "BaseQueryTxContext",
    max_messages: typing.Union[int, None] = None,
) -> typing.Union[datatypes.PublicBatch, None]:
    """
    Receive one batch of messages bound to transaction *tx*.

    All messages in the returned batch belong to the same partition.
    Wrap the call in asyncio.wait_for to wait with a timeout.
    """
    reconnector = self._reconnector
    await reconnector.wait_message()
    return reconnector.receive_batch_with_tx_nowait(
        tx=tx,
        max_messages=max_messages,
    )
115
143
async def receive_message (self ) -> typing .Optional [datatypes .PublicMessage ]:
116
144
"""
117
145
Block until receive new message
@@ -165,18 +193,27 @@ class ReaderReconnector:
165
193
_state_changed : asyncio .Event
166
194
_stream_reader : Optional ["ReaderStream" ]
167
195
_first_error : asyncio .Future [YdbError ]
196
+ _tx_to_batches_map : Dict [str , typing .List [datatypes .PublicBatch ]]
168
197
169
def __init__(
    self,
    driver: Driver,
    settings: topic_reader.PublicReaderSettings,
    loop: Optional[asyncio.AbstractEventLoop] = None,
):
    """Create a reconnector and start its background connection loop.

    *loop* defaults to the currently running event loop; it is used to
    schedule transaction callbacks from other threads/loops.
    """
    self._id = self._static_reader_reconnector_counter.inc_and_get()
    self._driver = driver
    self._settings = settings
    if loop is None:
        loop = asyncio.get_running_loop()
    self._loop = loop

    self._stream_reader = None
    self._state_changed = asyncio.Event()
    # tx_id -> batches whose offsets must be committed with that transaction
    self._tx_to_batches_map = dict()
    self._first_error = asyncio.get_running_loop().create_future()

    self._background_tasks = set()
    self._background_tasks.add(asyncio.create_task(self._connection_loop()))
180
217
async def _connection_loop (self ):
181
218
attempt = 0
182
219
while True :
@@ -190,6 +227,7 @@ async def _connection_loop(self):
190
227
if not retry_info .is_retriable :
191
228
self ._set_first_error (err )
192
229
return
230
+
193
231
await asyncio .sleep (retry_info .sleep_timeout_seconds )
194
232
195
233
attempt += 1
@@ -222,9 +260,87 @@ def receive_batch_nowait(self, max_messages: Optional[int] = None):
222
260
max_messages = max_messages ,
223
261
)
224
262
263
def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None):
    """Return an already-received batch bound to transaction *tx*, without waiting.

    On first use of *tx* registers commit/rollback callbacks on it, records the
    batch so its offsets are sent with the transaction commit, and schedules the
    partition offsets to be advanced after a successful commit.

    Returns None when no batch is currently buffered.
    """
    batch = self._stream_reader.receive_batch_nowait(
        max_messages=max_messages,
    )

    # No batch buffered: bail out before touching the tx bookkeeping —
    # otherwise `batch._update_partition_offsets` below would raise
    # AttributeError on None.
    if batch is None:
        return None

    self._init_tx(tx)

    self._tx_to_batches_map[tx.tx_id].append(batch)

    tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop)

    return batch
275
+
225
276
def receive_message_nowait(self):
    """Fetch the next buffered message from the current stream reader, without waiting."""
    stream_reader = self._stream_reader
    return stream_reader.receive_message_nowait()
227
278
279
def _init_tx(self, tx: "BaseQueryTxContext"):
    """Register reader callbacks on *tx* the first time this transaction is seen."""
    if tx.tx_id in self._tx_to_batches_map:
        return  # callbacks already registered for this transaction
    self._tx_to_batches_map[tx.tx_id] = []
    for event, callback in (
        (TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx),
        (TxEvent.AFTER_COMMIT, self._handle_after_tx_commit),
        (TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback),
    ):
        tx._add_callback(event, callback, self._loop)
285
+
286
async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"):
    """Send the offsets of all batches bound to *tx* as part of its commit.

    Runs as a BEFORE_COMMIT callback: groups the recorded batches by topic
    path and partition id, builds an UpdateOffsetsInTransaction request and
    sends it through the driver. On any failure the error is reported via
    tx._set_external_error and the stream reader's first-error slot rather
    than raised from here; the tx entry is dropped either way.
    """
    # topic_path -> partition_id -> [batches], in first-seen order.
    grouped_batches = defaultdict(lambda: defaultdict(list))
    for batch in self._tx_to_batches_map[tx.tx_id]:
        grouped_batches[batch._partition_session.topic_path][batch._partition_session.partition_id].append(batch)

    request = UpdateOffsetsInTransactionRequest(tx=tx._tx_identity(), consumer=self._settings.consumer, topics=[])

    # Translate the grouping into the request structure: one TopicOffsets per
    # topic, one PartitionOffsets per partition, one offsets range per batch.
    for topic_path in grouped_batches:
        topic_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets(path=topic_path, partitions=[])
        for partition_id in grouped_batches[topic_path]:
            partition_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
                partition_id=partition_id,
                partition_offsets=[
                    batch._commit_get_offsets_range() for batch in grouped_batches[topic_path][partition_id]
                ],
            )
            topic_offsets.partitions.append(partition_offsets)
        request.topics.append(topic_offsets)

    try:
        return await self._do_commit_batches_with_tx_call(request)
    except BaseException:
        # Surface the failure through the transaction and the reader instead
        # of raising from the commit callback.
        exc = issues.ClientInternalError("Failed to update offsets in tx.")
        tx._set_external_error(exc)
        self._stream_reader._set_first_error(exc)
    finally:
        del self._tx_to_batches_map[tx.tx_id]
313
+
314
async def _do_commit_batches_with_tx_call(self, request: UpdateOffsetsInTransactionRequest):
    """Issue the UpdateOffsetsInTransaction RPC through the driver.

    Awaits the driver directly when it is async; otherwise runs the
    synchronous driver call in a worker thread so the event loop is not
    blocked.
    """
    call_args = (
        request.to_proto(),
        _apis.TopicService.Stub,
        _apis.TopicService.UpdateOffsetsInTransaction,
        topic_common.wrap_operation,
    )

    if asyncio.iscoroutinefunction(self._driver.__call__):
        return await self._driver(*call_args)
    return await to_thread(self._driver, *call_args, executor=None)
328
+
329
async def _handle_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None:
    """Drop batches recorded for *tx* and force a reader reconnect after rollback.

    The incoming *exc* is ignored: rollback always invalidates the read
    session, so a reconnect error is raised unconditionally.
    """
    self._tx_to_batches_map.pop(tx.tx_id, None)
    internal_error = issues.ClientInternalError("Reconnect due to transaction rollback")
    self._stream_reader._set_first_error(internal_error)
334
+
335
async def _handle_after_tx_commit(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None:
    """Forget batches recorded for *tx*; trigger a reconnect if the commit failed."""
    self._tx_to_batches_map.pop(tx.tx_id, None)

    if exc is None:
        return
    self._stream_reader._set_first_error(
        issues.ClientInternalError("Reconnect due to transaction commit failed")
    )
343
+
228
344
def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter:
    """Forward a commit request for *batch* to the active stream reader."""
    stream_reader = self._stream_reader
    return stream_reader.commit(batch)
230
346
0 commit comments