diff --git a/examples/topic/basic_example.py b/examples/topic/basic_example.py index 50dd9a5d..18e9626f 100644 --- a/examples/topic/basic_example.py +++ b/examples/topic/basic_example.py @@ -9,7 +9,7 @@ async def connect(endpoint: str, database: str) -> ydb.aio.Driver: config = ydb.DriverConfig(endpoint=endpoint, database=database) config.credentials = ydb.credentials_from_env_variables() driver = ydb.aio.Driver(config) - await driver.wait(15) + await driver.wait(5, fail_fast=True) return driver @@ -25,7 +25,8 @@ async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str): async def write_messages(driver: ydb.aio.Driver, topic: str): async with driver.topic_client.writer(topic) as writer: for i in range(10): - await writer.write(f"mess-{i}") + mess = ydb.TopicWriterMessage(data=f"mess-{i}", metadata_items={"index": f"{i}"}) + await writer.write(mess) await asyncio.sleep(1) @@ -38,6 +39,7 @@ async def read_messages(driver: ydb.aio.Driver, topic: str, consumer: str): print(mess.seqno) print(mess.created_at) print(mess.data.decode()) + print(mess.metadata_items) reader.commit(mess) except asyncio.TimeoutError: return diff --git a/tests/conftest.py b/tests/conftest.py index 51b265d7..a8177f46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -263,6 +263,31 @@ async def topic_with_messages(driver, topic_consumer, database): return topic_path +@pytest.fixture() +@pytest.mark.asyncio() +async def topic_with_messages_with_metadata(driver, topic_consumer, database): + topic_path = database + "/test-topic-with-messages-with-metadata" + try: + await driver.topic_client.drop_topic(topic_path) + except issues.SchemeError: + pass + + await driver.topic_client.create_topic( + path=topic_path, + consumers=[topic_consumer], + ) + + writer = driver.topic_client.writer(topic_path, producer_id="fixture-producer-id", codec=ydb.TopicCodec.RAW) + await writer.write_with_ack( + [ + ydb.TopicWriterMessage(data="123".encode(), metadata_items={"key": "value"}), + 
ydb.TopicWriterMessage(data="456".encode(), metadata_items={"key": b"value"}), + ] + ) + await writer.close() + return topic_path + + @pytest.fixture() @pytest.mark.asyncio() async def topic_reader(driver, topic_consumer, topic_path) -> ydb.TopicReaderAsyncIO: diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 74c8bccd..23b5b4be 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -30,6 +30,21 @@ async def test_read_message(self, driver, topic_with_messages, topic_consumer): await reader.close() + async def test_read_metadata(self, driver, topic_with_messages_with_metadata, topic_consumer): + reader = driver.topic_client.reader(topic_with_messages_with_metadata, topic_consumer) + + expected_metadata_items = {"key": b"value"} + + for _ in range(2): + await reader.wait_message() + msg = await reader.receive_message() + + assert msg is not None + assert msg.metadata_items + assert msg.metadata_items == expected_metadata_items + + await reader.close() + async def test_read_and_commit_with_close_reader(self, driver, topic_with_messages, topic_consumer): async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: message = await reader.receive_message() @@ -135,6 +150,20 @@ def test_read_message(self, driver_sync, topic_with_messages, topic_consumer): reader.close() + def test_read_metadata(self, driver_sync, topic_with_messages_with_metadata, topic_consumer): + reader = driver_sync.topic_client.reader(topic_with_messages_with_metadata, topic_consumer) + + expected_metadata_items = {"key": b"value"} + + for _ in range(2): + msg = reader.receive_message() + + assert msg is not None + assert msg.metadata_items + assert msg.metadata_items == expected_metadata_items + + reader.close() + def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_messages, topic_consumer): with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: message 
= reader.receive_message() diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py index a3cdbe9d..ba5ae74c 100644 --- a/tests/topics/test_topic_writer.py +++ b/tests/topics/test_topic_writer.py @@ -15,6 +15,11 @@ async def test_send_message(self, driver: ydb.aio.Driver, topic_path): await writer.write(ydb.TopicWriterMessage(data="123".encode())) await writer.close() + async def test_send_message_with_metadata(self, driver: ydb.aio.Driver, topic_path): + writer = driver.topic_client.writer(topic_path, producer_id="test") + await writer.write(ydb.TopicWriterMessage(data="123".encode(), metadata_items={"key": "value"})) + await writer.close() + async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path): async with driver.topic_client.writer( topic_path, @@ -136,6 +141,11 @@ def test_send_message(self, driver_sync: ydb.Driver, topic_path): writer.write(ydb.TopicWriterMessage(data="123".encode())) writer.close() + def test_send_message_with_metadata(self, driver_sync: ydb.Driver, topic_path): + writer = driver_sync.topic_client.writer(topic_path, producer_id="test") + writer.write(ydb.TopicWriterMessage(data="123".encode(), metadata_items={"key": "value"})) + writer.close() + def test_wait_last_seqno(self, driver_sync: ydb.Driver, topic_path): with driver_sync.topic_client.writer( topic_path, diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index c1789b6c..ec84ab08 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -208,6 +208,7 @@ class MessageData(IToProto): data: bytes uncompressed_size: int partitioning: "StreamWriteMessage.PartitioningType" + metadata_items: Dict[str, bytes] def to_proto( self, @@ -218,6 +219,10 @@ def to_proto( proto.data = self.data proto.uncompressed_size = self.uncompressed_size + for key, value in self.metadata_items.items(): + item = ydb_topic_pb2.MetadataItem(key=key, value=value) + proto.metadata_items.append(item) + if 
self.partitioning is None: pass elif isinstance(self.partitioning, StreamWriteMessage.PartitioningPartitionID): @@ -489,16 +494,19 @@ class MessageData(IFromProto): data: bytes uncompresed_size: int message_group_id: str + metadata_items: Dict[str, bytes] @staticmethod def from_proto( msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData, ) -> "StreamReadMessage.ReadResponse.MessageData": + metadata_items = {meta.key: meta.value for meta in msg.metadata_items} return StreamReadMessage.ReadResponse.MessageData( offset=msg.offset, seq_no=msg.seq_no, created_at=msg.created_at.ToDatetime(), data=msg.data, + metadata_items=metadata_items, uncompresed_size=msg.uncompressed_size, message_group_id=msg.message_group_id, ) diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 01501638..a9c811ac 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -40,6 +40,7 @@ class PublicMessage(ICommittable, ISessionAlive): written_at: datetime.datetime producer_id: str data: Union[bytes, Any] # set as original decompressed bytes or deserialized object if deserializer set in reader + metadata_items: Dict[str, bytes] _partition_session: PartitionSession _commit_start_offset: int _commit_end_offset: int diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 351efb9a..e407fe01 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -627,6 +627,7 @@ def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> written_at=server_batch.written_at, producer_id=server_batch.producer_id, data=message_data.data, + metadata_items=message_data.metadata_items, _partition_session=partition_session, _commit_start_offset=partition_session._next_message_start_commit_offset, _commit_end_offset=message_data.offset + 1, diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py 
b/ydb/_topic_reader/topic_reader_asyncio_test.py index b9f1e639..25e08029 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -74,6 +74,7 @@ def stub_message(id: int): written_at=datetime.datetime(2023, 3, 18, 14, 15), producer_id="", data=bytes(), + metadata_items={}, _partition_session=stub_partition_session(), _commit_start_offset=0, _commit_end_offset=1, @@ -207,6 +208,7 @@ def create_message( written_at=datetime.datetime(2023, 2, 3, 14, 16), producer_id="test-producer-id", data=bytes(), + metadata_items={}, _partition_session=partition_session, _commit_start_offset=partition_session._next_message_start_commit_offset + offset_delta - 1, _commit_end_offset=partition_session._next_message_start_commit_offset + offset_delta, @@ -250,6 +252,7 @@ def batch_size(): seq_no=message.seqno, created_at=message.created_at, data=message.data, + metadata_items={}, uncompresed_size=len(message.data), message_group_id=message.message_group_id, ) @@ -445,6 +448,7 @@ async def test_commit_ranges_for_received_messages( written_at=datetime.datetime(2023, 3, 14, 15, 42), producer_id="asd", data=rb"123", + metadata_items={}, _partition_session=None, _commit_start_offset=5, _commit_end_offset=15, @@ -468,6 +472,7 @@ async def test_commit_ranges_for_received_messages( written_at=datetime.datetime(2023, 3, 14, 15, 42), producer_id="asd", data=gzip.compress(rb"123"), + metadata_items={}, _partition_session=None, _commit_start_offset=5, _commit_end_offset=15, @@ -490,6 +495,7 @@ async def test_commit_ranges_for_received_messages( offset=1, written_at=datetime.datetime(2023, 3, 14, 15, 42), producer_id="asd", + metadata_items={}, data=rb"123", _partition_session=None, _commit_start_offset=5, @@ -504,6 +510,7 @@ async def test_commit_ranges_for_received_messages( written_at=datetime.datetime(2023, 3, 14, 15, 42), producer_id="asd", data=rb"456", + metadata_items={}, _partition_session=None, _commit_start_offset=5, 
_commit_end_offset=15, @@ -527,6 +534,7 @@ async def test_commit_ranges_for_received_messages( written_at=datetime.datetime(2023, 3, 14, 15, 42), producer_id="asd", data=gzip.compress(rb"123"), + metadata_items={}, _partition_session=None, _commit_start_offset=5, _commit_end_offset=15, @@ -540,6 +548,7 @@ async def test_commit_ranges_for_received_messages( written_at=datetime.datetime(2023, 3, 14, 15, 42), producer_id="asd", data=gzip.compress(rb"456"), + metadata_items={}, _partition_session=None, _commit_start_offset=5, _commit_end_offset=15, @@ -766,6 +775,7 @@ async def test_free_buffer_after_partition_stop(self, stream, stream_reader, par seq_no=123, created_at=t, data=bytes(), + metadata_items={}, uncompresed_size=message_size, message_group_id="test-message-group", ) @@ -846,6 +856,7 @@ def reader_batch_count(): created_at=created_at, data=data, uncompresed_size=len(data), + metadata_items={}, message_group_id=message_group_id, ) ], @@ -877,6 +888,7 @@ def reader_batch_count(): written_at=written_at, producer_id=producer_id, data=data, + metadata_items={}, _partition_session=partition_session, _commit_start_offset=expected_message_offset, _commit_end_offset=expected_message_offset + 1, @@ -923,6 +935,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti seq_no=3, created_at=created_at, data=data, + metadata_items={}, uncompresed_size=len(data), message_group_id=message_group_id, ) @@ -944,6 +957,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti seq_no=2, created_at=created_at2, data=data, + metadata_items={}, uncompresed_size=len(data), message_group_id=message_group_id, ) @@ -960,6 +974,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti seq_no=3, created_at=created_at3, data=data2, + metadata_items={}, uncompresed_size=len(data2), message_group_id=message_group_id, ), @@ -968,6 +983,7 @@ async def test_read_batches(self, stream_reader, partition_session, 
second_parti seq_no=5, created_at=created_at4, data=data, + metadata_items={}, uncompresed_size=len(data), message_group_id=message_group_id2, ), @@ -998,6 +1014,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti written_at=written_at, producer_id=producer_id, data=data, + metadata_items={}, _partition_session=partition_session, _commit_start_offset=partition1_mess1_expected_offset, _commit_end_offset=partition1_mess1_expected_offset + 1, @@ -1018,6 +1035,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti written_at=written_at2, producer_id=producer_id, data=data, + metadata_items={}, _partition_session=second_partition_session, _commit_start_offset=partition2_mess1_expected_offset, _commit_end_offset=partition2_mess1_expected_offset + 1, @@ -1038,6 +1056,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti written_at=written_at2, producer_id=producer_id2, data=data2, + metadata_items={}, _partition_session=second_partition_session, _commit_start_offset=partition2_mess2_expected_offset, _commit_end_offset=partition2_mess2_expected_offset + 1, @@ -1051,6 +1070,7 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti written_at=written_at2, producer_id=producer_id, data=data, + metadata_items={}, _partition_session=second_partition_session, _commit_start_offset=partition2_mess3_expected_offset, _commit_end_offset=partition2_mess3_expected_offset + 1, diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 527bf03e..aa5fe974 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -15,7 +15,7 @@ from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .. 
import connection -Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] +Message = typing.Union["PublicMessage", "PublicMessage.SimpleSourceType"] @dataclass @@ -91,20 +91,23 @@ class PublicWriterInitInfo: class PublicMessage: seqno: Optional[int] created_at: Optional[datetime.datetime] - data: "PublicMessage.SimpleMessageSourceType" + data: "PublicMessage.SimpleSourceType" + metadata_items: Optional[Dict[str, "PublicMessage.SimpleSourceType"]] - SimpleMessageSourceType = Union[str, bytes] # Will be extend + SimpleSourceType = Union[str, bytes] # Will be extended def __init__( self, - data: SimpleMessageSourceType, + data: SimpleSourceType, *, + metadata_items: Optional[Dict[str, "PublicMessage.SimpleSourceType"]] = None, seqno: Optional[int] = None, created_at: Optional[datetime.datetime] = None, ): self.seqno = seqno self.created_at = created_at self.data = data + self.metadata_items = metadata_items @staticmethod def _create_message(data: Message) -> "PublicMessage": @@ -117,30 +120,37 @@ class InternalMessage(StreamWriteMessage.WriteRequest.MessageData, IToProto): codec: PublicCodec def __init__(self, mess: PublicMessage): + metadata_items = mess.metadata_items or {} super().__init__( seq_no=mess.seqno, created_at=mess.created_at, data=mess.data, + metadata_items=metadata_items, uncompressed_size=len(mess.data), partitioning=None, ) self.codec = PublicCodec.RAW - def get_bytes(self) -> bytes: - if self.data is None: + def _get_bytes(self, obj: Optional[PublicMessage.SimpleSourceType]) -> bytes: + if obj is None: return bytes() - if isinstance(self.data, bytes): - return self.data - if isinstance(self.data, str): - return self.data.encode("utf-8") + if isinstance(obj, bytes): + return obj + if isinstance(obj, str): + return obj.encode("utf-8") raise ValueError("Bad data type") + def get_data_bytes(self) -> bytes: + return self._get_bytes(self.data) + def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData: - data =
self.get_bytes() + data = self.get_data_bytes() + metadata_items = {key: self._get_bytes(value) for key, value in self.metadata_items.items()} return StreamWriteMessage.WriteRequest.MessageData( seq_no=self.seq_no, created_at=self.created_at, data=data, + metadata_items=metadata_items, uncompressed_size=len(data), partitioning=None, # unsupported by server now ) @@ -221,6 +231,7 @@ def messages_to_proto_requests( seq_no=_max_int, created_at=datetime.datetime(3000, 1, 1, 1, 1, 1, 1), data=bytes(1), + metadata_items={}, uncompressed_size=_max_int, partitioning=StreamWriteMessage.PartitioningMessageGroupID( message_group_id="a" * 100, diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 869808f7..32d8fefe 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -427,7 +427,7 @@ async def _encode_data_inplace(self, codec: PublicCodec, messages: List[Internal for message in messages: encoded_data_futures = eventloop.run_in_executor( - self._encode_executor, encoder_function, message.get_bytes() + self._encode_executor, encoder_function, message.get_data_bytes() ) encode_waiters.append(encoded_data_futures) @@ -493,7 +493,7 @@ def get_compressed_size(codec) -> int: f = self._codec_functions[codec] for m in test_messages: - encoded = f(m.get_bytes()) + encoded = f(m.get_data_bytes()) s += len(encoded) return s diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index dc3f2cad..b288d0f0 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -153,6 +153,7 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): seq_no=1, created_at=now, data=data, + metadata_items={}, uncompressed_size=len(data), partitioning=None, ) @@ -544,7 +545,7 @@ def add_messages(_self, messages: typing.List[InternalMessage]): mess = mess[0] assert mess.codec 
== expected_codecs[i] - assert mess.get_bytes() == expected_datas[i] + assert mess.get_data_bytes() == expected_datas[i] await reconnector.close(flush=False) @@ -575,7 +576,7 @@ async def test_encode_data_inplace( for index, mess in enumerate(messages): assert mess.codec == codec - assert mess.get_bytes() == expected_datas[index] + assert mess.get_data_bytes() == expected_datas[index] await reconnector.close(flush=True)