diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index a7febd5b..e8815333 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -145,7 +145,7 @@ def close(self): SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver] -class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO): +class AbstractGrpcWrapperAsyncIO(IGrpcWrapperAsyncIO, abc.ABC): from_client_grpc: asyncio.Queue from_server_grpc: AsyncIterator convert_server_grpc_to_wrapper: Callable[[Any], Any] @@ -163,13 +163,6 @@ def __init__(self, convert_server_grpc_to_wrapper): def __del__(self): self._clean_executor(wait=False) - async def start(self, driver: SupportedDriverType, stub, method): - if asyncio.iscoroutinefunction(driver.__call__): - await self._start_asyncio_driver(driver, stub, method) - else: - await self._start_sync_driver(driver, stub, method) - self._connection_state = "started" - def close(self): self.from_client_grpc.put_nowait(_stop_grpc_connection_marker) if self._stream_call: @@ -181,6 +174,35 @@ def _clean_executor(self, wait: bool): if self._wait_executor: self._wait_executor.shutdown(wait) + async def receive(self) -> Any: + # todo handle grpc exceptions and convert it to internal exceptions + try: + grpc_message = await self.from_server_grpc.__anext__() + except (grpc.RpcError, grpc.aio.AioRpcError) as e: + raise connection._rpc_error_handler(self._connection_state, e) + + issues._process_response(grpc_message) + + if self._connection_state != "has_received_messages": + self._connection_state = "has_received_messages" + + # print("rekby, grpc, received", grpc_message) + return self.convert_server_grpc_to_wrapper(grpc_message) + + def write(self, wrap_message: IToProto): + grpc_message = wrap_message.to_proto() + # print("rekby, grpc, send", grpc_message) + self.from_client_grpc.put_nowait(grpc_message) + + +class GrpcWrapperStreamStreamAsyncIO(AbstractGrpcWrapperAsyncIO): + async def start(self, driver: SupportedDriverType, stub, method): + if asyncio.iscoroutinefunction(driver.__call__): + await self._start_asyncio_driver(driver, stub, method) + else: + await self._start_sync_driver(driver, stub, method) + self._connection_state = "started" + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method): requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc) stream_call = await driver( @@ -199,25 +221,30 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method): self._stream_call = stream_call self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) - async def receive(self) -> Any: - # todo handle grpc exceptions and convert it to internal exceptions - try: - grpc_message = await self.from_server_grpc.__anext__() - except (grpc.RpcError, grpc.aio.AioRpcError) as e: - raise connection._rpc_error_handler(self._connection_state, e) - issues._process_response(grpc_message) +class GrpcWrapperUnaryStreamAsyncIO(AbstractGrpcWrapperAsyncIO): + async def start(self, driver: SupportedDriverType, request, stub, method): + if asyncio.iscoroutinefunction(driver.__call__): + await self._start_asyncio_driver(driver, request, stub, method) + else: + await self._start_sync_driver(driver, request, stub, method) + self._connection_state = "started" - if self._connection_state != "has_received_messages": - self._connection_state = "has_received_messages" + async def _start_asyncio_driver(self, driver: ydb.aio.Driver, request, stub, method): + stream_call = await driver( + request, + stub, + method, + ) + self._stream_call = stream_call + self.from_server_grpc = stream_call.__aiter__() - # print("rekby, grpc, received", grpc_message) - return self.convert_server_grpc_to_wrapper(grpc_message) + async def _start_sync_driver(self, driver: ydb.Driver, request, stub, method): + self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - def write(self, wrap_message: IToProto): - grpc_message = wrap_message.to_proto() - # print("rekby, grpc, send", grpc_message) - self.from_client_grpc.put_nowait(grpc_message) + stream_call = await to_thread(driver, request, stub, method, executor=self._wait_executor) + self._stream_call = stream_call + self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor) @dataclass(init=False) diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index b31f9af9..1dadaa04 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -8,7 +8,7 @@ from .common import CallFromSyncToAsync from .._grpc.grpcwrapper.common_utils import ( - GrpcWrapperAsyncIO, + GrpcWrapperStreamStreamAsyncIO, ServerStatus, callback_from_asyncio, ) @@ -77,7 +77,7 @@ async def async_failed(): @pytest.mark.asyncio -class TestGrpcWrapperAsyncIO: +class TestGrpcWrapperStreamStreamAsyncIO: async def test_convert_grpc_errors_to_ydb(self): class TestError(grpc.RpcError, grpc.Call): def __init__(self): @@ -93,7 +93,7 @@ class FromServerMock: async def __anext__(self): raise TestError() - wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None) wrapper.from_server_grpc = FromServerMock() with pytest.raises(issues.Unauthenticated): @@ -107,7 +107,7 @@ async def __anext__(self): issues=[], ) - wrapper = GrpcWrapperAsyncIO(lambda: None) + wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None) wrapper.from_server_grpc = FromServerMock() with pytest.raises(issues.Overloaded): diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 81c6d9f4..8cc48a1d 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -18,7 +18,7 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, - GrpcWrapperAsyncIO, + GrpcWrapperStreamStreamAsyncIO, ) from .._grpc.grpcwrapper.ydb_topic import ( StreamReadMessage, @@ -308,7 +308,7 @@ async def create( driver: SupportedDriverType, settings: topic_reader.PublicReaderSettings, ) -> "ReaderStream": - stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto) + stream = GrpcWrapperStreamStreamAsyncIO(StreamReadMessage.FromServer.from_proto) await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 007c8a54..064f19ce 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -39,7 +39,7 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, - GrpcWrapperAsyncIO, + GrpcWrapperStreamStreamAsyncIO, ) logger = logging.getLogger(__name__) @@ -613,7 +613,7 @@ async def create( init_request: StreamWriteMessage.InitRequest, update_token_interval: Optional[Union[int, float]] = None, ) -> "WriterAsyncIOStream": - stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) + stream = GrpcWrapperStreamStreamAsyncIO(StreamWriteMessage.FromServer.from_proto) await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)