Skip to content

Commit 02d6ca6

Browse files
committed
Intermediate changes
commit_hash:bdf980a10265fd16f4aecae0dff78846b425bd8b
1 parent c33aace commit 02d6ca6

File tree

22 files changed

+690
-46
lines changed

22 files changed

+690
-46
lines changed

contrib/python/ydb/py3/.dist-info/METADATA

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Metadata-Version: 2.1
22
Name: ydb
3-
Version: 3.19.3
3+
Version: 3.20.1
44
Summary: YDB Python SDK
55
Home-page: http://github.com/ydb-platform/ydb-python-sdk
66
Author: Yandex LLC

contrib/python/ydb/py3/ya.make

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
PY3_LIBRARY()
44

5-
VERSION(3.19.3)
5+
VERSION(3.20.1)
66

77
LICENSE(Apache-2.0)
88

contrib/python/ydb/py3/ydb/_apis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class TopicService(object):
115115
DropTopic = "DropTopic"
116116
StreamRead = "StreamRead"
117117
StreamWrite = "StreamWrite"
118+
UpdateOffsetsInTransaction = "UpdateOffsetsInTransaction"
118119

119120

120121
class QueryService(object):

contrib/python/ydb/py3/ydb/_errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
_errors_retriable_fast_backoff_types = [
77
issues.Unavailable,
8+
issues.ClientInternalError,
89
]
910
_errors_retriable_slow_backoff_types = [
1011
issues.Aborted,

contrib/python/ydb/py3/ydb/_grpc/grpcwrapper/ydb_topic.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,18 @@ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any:
141141
########################################################################################################################
142142

143143

144+
@dataclass
145+
class TransactionIdentity(IToProto):
146+
tx_id: str
147+
session_id: str
148+
149+
def to_proto(self) -> ydb_topic_pb2.TransactionIdentity:
150+
return ydb_topic_pb2.TransactionIdentity(
151+
id=self.tx_id,
152+
session=self.session_id,
153+
)
154+
155+
144156
class StreamWriteMessage:
145157
@dataclass()
146158
class InitRequest(IToProto):
@@ -199,6 +211,7 @@ def from_proto(
199211
class WriteRequest(IToProto):
200212
messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"]
201213
codec: int
214+
tx_identity: Optional[TransactionIdentity]
202215

203216
@dataclass
204217
class MessageData(IToProto):
@@ -237,6 +250,9 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest:
237250
proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest()
238251
proto.codec = self.codec
239252

253+
if self.tx_identity is not None:
254+
proto.tx.CopyFrom(self.tx_identity.to_proto())
255+
240256
for message in self.messages:
241257
proto_mess = proto.messages.add()
242258
proto_mess.CopyFrom(message.to_proto())
@@ -297,6 +313,8 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr
297313
)
298314
except ValueError:
299315
message_write_status = reason
316+
elif proto_ack.HasField("written_in_tx"):
317+
message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWrittenInTx()
300318
else:
301319
raise NotImplementedError("unexpected ack status")
302320

@@ -309,6 +327,9 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr
309327
class StatusWritten:
310328
offset: int
311329

330+
class StatusWrittenInTx:
331+
pass
332+
312333
@dataclass
313334
class StatusSkipped:
314335
reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason"
@@ -1196,6 +1217,52 @@ def to_public(self) -> ydb_topic_public_types.PublicMeteringMode:
11961217
return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED
11971218

11981219

1220+
@dataclass
1221+
class UpdateOffsetsInTransactionRequest(IToProto):
1222+
tx: TransactionIdentity
1223+
topics: List[UpdateOffsetsInTransactionRequest.TopicOffsets]
1224+
consumer: str
1225+
1226+
def to_proto(self):
1227+
return ydb_topic_pb2.UpdateOffsetsInTransactionRequest(
1228+
tx=self.tx.to_proto(),
1229+
consumer=self.consumer,
1230+
topics=list(
1231+
map(
1232+
UpdateOffsetsInTransactionRequest.TopicOffsets.to_proto,
1233+
self.topics,
1234+
)
1235+
),
1236+
)
1237+
1238+
@dataclass
1239+
class TopicOffsets(IToProto):
1240+
path: str
1241+
partitions: List[UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets]
1242+
1243+
def to_proto(self):
1244+
return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets(
1245+
path=self.path,
1246+
partitions=list(
1247+
map(
1248+
UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets.to_proto,
1249+
self.partitions,
1250+
)
1251+
),
1252+
)
1253+
1254+
@dataclass
1255+
class PartitionOffsets(IToProto):
1256+
partition_id: int
1257+
partition_offsets: List[OffsetsRange]
1258+
1259+
def to_proto(self) -> ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets:
1260+
return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
1261+
partition_id=self.partition_id,
1262+
partition_offsets=list(map(OffsetsRange.to_proto, self.partition_offsets)),
1263+
)
1264+
1265+
11991266
@dataclass
12001267
class CreateTopicRequest(IToProto, IFromPublic):
12011268
path: str

contrib/python/ydb/py3/ydb/_topic_reader/datatypes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def ack_notify(self, offset: int):
108108
waiter = self._ack_waiters.popleft()
109109
waiter._finish_ok()
110110

111+
def _update_last_commited_offset_if_needed(self, offset: int):
112+
self.committed_offset = max(self.committed_offset, offset)
113+
111114
def close(self):
112115
if self.closed:
113116
return
@@ -211,3 +214,9 @@ def _pop_batch(self, message_count: int) -> PublicBatch:
211214
self._bytes_size = self._bytes_size - new_batch._bytes_size
212215

213216
return new_batch
217+
218+
def _update_partition_offsets(self, tx, exc=None):
219+
if exc is not None:
220+
return
221+
offsets = self._commit_get_offsets_range()
222+
self._partition_session._update_last_commited_offset_if_needed(offsets.end)

contrib/python/ydb/py3/ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import gzip
66
import typing
77
from asyncio import Task
8-
from collections import OrderedDict
8+
from collections import defaultdict, OrderedDict
99
from typing import Optional, Set, Dict, Union, Callable
1010

1111
import ydb
@@ -19,17 +19,24 @@
1919
from .._grpc.grpcwrapper.common_utils import (
2020
IGrpcWrapperAsyncIO,
2121
SupportedDriverType,
22+
to_thread,
2223
GrpcWrapperAsyncIO,
2324
)
2425
from .._grpc.grpcwrapper.ydb_topic import (
2526
StreamReadMessage,
2627
UpdateTokenRequest,
2728
UpdateTokenResponse,
29+
UpdateOffsetsInTransactionRequest,
2830
Codec,
2931
)
3032
from .._errors import check_retriable_error
3133
import logging
3234

35+
from ..query.base import TxEvent
36+
37+
if typing.TYPE_CHECKING:
38+
from ..query.transaction import BaseQueryTxContext
39+
3340
logger = logging.getLogger(__name__)
3441

3542

@@ -77,7 +84,7 @@ def __init__(
7784
):
7885
self._loop = asyncio.get_running_loop()
7986
self._closed = False
80-
self._reconnector = ReaderReconnector(driver, settings)
87+
self._reconnector = ReaderReconnector(driver, settings, self._loop)
8188
self._parent = _parent
8289

8390
async def __aenter__(self):
@@ -88,8 +95,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
8895

8996
def __del__(self):
9097
if not self._closed:
91-
task = self._loop.create_task(self.close(flush=False))
92-
topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader")
98+
try:
99+
logger.warning("Topic reader was not closed properly. Consider using method close().")
100+
task = self._loop.create_task(self.close(flush=False))
101+
topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader")
102+
except BaseException:
103+
logger.warning("Something went wrong during reader close in __del__")
93104

94105
async def wait_message(self):
95106
"""
@@ -112,6 +123,23 @@ async def receive_batch(
112123
max_messages=max_messages,
113124
)
114125

126+
async def receive_batch_with_tx(
127+
self,
128+
tx: "BaseQueryTxContext",
129+
max_messages: typing.Union[int, None] = None,
130+
) -> typing.Union[datatypes.PublicBatch, None]:
131+
"""
132+
Get one messages batch with tx from reader.
133+
All messages in a batch from same partition.
134+
135+
use asyncio.wait_for for wait with timeout.
136+
"""
137+
await self._reconnector.wait_message()
138+
return self._reconnector.receive_batch_with_tx_nowait(
139+
tx=tx,
140+
max_messages=max_messages,
141+
)
142+
115143
async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
116144
"""
117145
Block until receive new message
@@ -165,18 +193,27 @@ class ReaderReconnector:
165193
_state_changed: asyncio.Event
166194
_stream_reader: Optional["ReaderStream"]
167195
_first_error: asyncio.Future[YdbError]
196+
_tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]]
168197

169-
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
198+
def __init__(
199+
self,
200+
driver: Driver,
201+
settings: topic_reader.PublicReaderSettings,
202+
loop: Optional[asyncio.AbstractEventLoop] = None,
203+
):
170204
self._id = self._static_reader_reconnector_counter.inc_and_get()
171205
self._settings = settings
172206
self._driver = driver
207+
self._loop = loop if loop is not None else asyncio.get_running_loop()
173208
self._background_tasks = set()
174209

175210
self._state_changed = asyncio.Event()
176211
self._stream_reader = None
177212
self._background_tasks.add(asyncio.create_task(self._connection_loop()))
178213
self._first_error = asyncio.get_running_loop().create_future()
179214

215+
self._tx_to_batches_map = dict()
216+
180217
async def _connection_loop(self):
181218
attempt = 0
182219
while True:
@@ -190,6 +227,7 @@ async def _connection_loop(self):
190227
if not retry_info.is_retriable:
191228
self._set_first_error(err)
192229
return
230+
193231
await asyncio.sleep(retry_info.sleep_timeout_seconds)
194232

195233
attempt += 1
@@ -222,9 +260,87 @@ def receive_batch_nowait(self, max_messages: Optional[int] = None):
222260
max_messages=max_messages,
223261
)
224262

263+
def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None):
264+
batch = self._stream_reader.receive_batch_nowait(
265+
max_messages=max_messages,
266+
)
267+
268+
self._init_tx(tx)
269+
270+
self._tx_to_batches_map[tx.tx_id].append(batch)
271+
272+
tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop)
273+
274+
return batch
275+
225276
def receive_message_nowait(self):
226277
return self._stream_reader.receive_message_nowait()
227278

279+
def _init_tx(self, tx: "BaseQueryTxContext"):
280+
if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks
281+
self._tx_to_batches_map[tx.tx_id] = []
282+
tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, self._loop)
283+
tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop)
284+
tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop)
285+
286+
async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"):
287+
grouped_batches = defaultdict(lambda: defaultdict(list))
288+
for batch in self._tx_to_batches_map[tx.tx_id]:
289+
grouped_batches[batch._partition_session.topic_path][batch._partition_session.partition_id].append(batch)
290+
291+
request = UpdateOffsetsInTransactionRequest(tx=tx._tx_identity(), consumer=self._settings.consumer, topics=[])
292+
293+
for topic_path in grouped_batches:
294+
topic_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets(path=topic_path, partitions=[])
295+
for partition_id in grouped_batches[topic_path]:
296+
partition_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
297+
partition_id=partition_id,
298+
partition_offsets=[
299+
batch._commit_get_offsets_range() for batch in grouped_batches[topic_path][partition_id]
300+
],
301+
)
302+
topic_offsets.partitions.append(partition_offsets)
303+
request.topics.append(topic_offsets)
304+
305+
try:
306+
return await self._do_commit_batches_with_tx_call(request)
307+
except BaseException:
308+
exc = issues.ClientInternalError("Failed to update offsets in tx.")
309+
tx._set_external_error(exc)
310+
self._stream_reader._set_first_error(exc)
311+
finally:
312+
del self._tx_to_batches_map[tx.tx_id]
313+
314+
async def _do_commit_batches_with_tx_call(self, request: UpdateOffsetsInTransactionRequest):
315+
args = [
316+
request.to_proto(),
317+
_apis.TopicService.Stub,
318+
_apis.TopicService.UpdateOffsetsInTransaction,
319+
topic_common.wrap_operation,
320+
]
321+
322+
if asyncio.iscoroutinefunction(self._driver.__call__):
323+
res = await self._driver(*args)
324+
else:
325+
res = await to_thread(self._driver, *args, executor=None)
326+
327+
return res
328+
329+
async def _handle_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None:
330+
if tx.tx_id in self._tx_to_batches_map:
331+
del self._tx_to_batches_map[tx.tx_id]
332+
exc = issues.ClientInternalError("Reconnect due to transaction rollback")
333+
self._stream_reader._set_first_error(exc)
334+
335+
async def _handle_after_tx_commit(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None:
336+
if tx.tx_id in self._tx_to_batches_map:
337+
del self._tx_to_batches_map[tx.tx_id]
338+
339+
if exc is not None:
340+
self._stream_reader._set_first_error(
341+
issues.ClientInternalError("Reconnect due to transaction commit failed")
342+
)
343+
228344
def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter:
229345
return self._stream_reader.commit(batch)
230346

0 commit comments

Comments
 (0)