Skip to content

Refactor wrappers to split UnaryStream and StreamStream wrappers #451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 50 additions & 23 deletions ydb/_grpc/grpcwrapper/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def close(self):
SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver]


class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
class AbstractGrpcWrapperAsyncIO(IGrpcWrapperAsyncIO, abc.ABC):
from_client_grpc: asyncio.Queue
from_server_grpc: AsyncIterator
convert_server_grpc_to_wrapper: Callable[[Any], Any]
Expand All @@ -163,13 +163,6 @@ def __init__(self, convert_server_grpc_to_wrapper):
def __del__(self):
self._clean_executor(wait=False)

async def start(self, driver: SupportedDriverType, stub, method):
if asyncio.iscoroutinefunction(driver.__call__):
await self._start_asyncio_driver(driver, stub, method)
else:
await self._start_sync_driver(driver, stub, method)
self._connection_state = "started"

def close(self):
self.from_client_grpc.put_nowait(_stop_grpc_connection_marker)
if self._stream_call:
Expand All @@ -181,6 +174,35 @@ def _clean_executor(self, wait: bool):
if self._wait_executor:
self._wait_executor.shutdown(wait)

async def receive(self) -> Any:
# todo handle grpc exceptions and convert it to internal exceptions
try:
grpc_message = await self.from_server_grpc.__anext__()
except (grpc.RpcError, grpc.aio.AioRpcError) as e:
raise connection._rpc_error_handler(self._connection_state, e)

issues._process_response(grpc_message)

if self._connection_state != "has_received_messages":
self._connection_state = "has_received_messages"

# print("rekby, grpc, received", grpc_message)
return self.convert_server_grpc_to_wrapper(grpc_message)

def write(self, wrap_message: IToProto):
grpc_message = wrap_message.to_proto()
# print("rekby, grpc, send", grpc_message)
self.from_client_grpc.put_nowait(grpc_message)


class GrpcWrapperStreamStreamAsyncIO(AbstractGrpcWrapperAsyncIO):
async def start(self, driver: SupportedDriverType, stub, method):
if asyncio.iscoroutinefunction(driver.__call__):
await self._start_asyncio_driver(driver, stub, method)
else:
await self._start_sync_driver(driver, stub, method)
self._connection_state = "started"

async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
stream_call = await driver(
Expand All @@ -199,25 +221,30 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
self._stream_call = stream_call
self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor)

async def receive(self) -> Any:
# todo handle grpc exceptions and convert it to internal exceptions
try:
grpc_message = await self.from_server_grpc.__anext__()
except (grpc.RpcError, grpc.aio.AioRpcError) as e:
raise connection._rpc_error_handler(self._connection_state, e)

issues._process_response(grpc_message)
class GrpcWrapperUnaryStreamAsyncIO(AbstractGrpcWrapperAsyncIO):
async def start(self, driver: SupportedDriverType, request, stub, method):
if asyncio.iscoroutinefunction(driver.__call__):
await self._start_asyncio_driver(driver, request, stub, method)
else:
await self._start_sync_driver(driver, request, stub, method)
self._connection_state = "started"

if self._connection_state != "has_received_messages":
self._connection_state = "has_received_messages"
async def _start_asyncio_driver(self, driver: ydb.aio.Driver, request, stub, method):
stream_call = await driver(
request,
stub,
method,
)
self._stream_call = stream_call
self.from_server_grpc = stream_call.__aiter__()

# print("rekby, grpc, received", grpc_message)
return self.convert_server_grpc_to_wrapper(grpc_message)
async def _start_sync_driver(self, driver: ydb.Driver, request, stub, method):
self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def write(self, wrap_message: IToProto):
grpc_message = wrap_message.to_proto()
# print("rekby, grpc, send", grpc_message)
self.from_client_grpc.put_nowait(grpc_message)
stream_call = await to_thread(driver, request, stub, method, executor=self._wait_executor)
self._stream_call = stream_call
self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor)


@dataclass(init=False)
Expand Down
8 changes: 4 additions & 4 deletions ydb/_topic_common/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .common import CallFromSyncToAsync
from .._grpc.grpcwrapper.common_utils import (
GrpcWrapperAsyncIO,
GrpcWrapperStreamStreamAsyncIO,
ServerStatus,
callback_from_asyncio,
)
Expand Down Expand Up @@ -77,7 +77,7 @@ async def async_failed():


@pytest.mark.asyncio
class TestGrpcWrapperAsyncIO:
class TestGrpcWrapperStreamStreamAsyncIO:
async def test_convert_grpc_errors_to_ydb(self):
class TestError(grpc.RpcError, grpc.Call):
def __init__(self):
Expand All @@ -93,7 +93,7 @@ class FromServerMock:
async def __anext__(self):
raise TestError()

wrapper = GrpcWrapperAsyncIO(lambda: None)
wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None)
wrapper.from_server_grpc = FromServerMock()

with pytest.raises(issues.Unauthenticated):
Expand All @@ -107,7 +107,7 @@ async def __anext__(self):
issues=[],
)

wrapper = GrpcWrapperAsyncIO(lambda: None)
wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None)
wrapper.from_server_grpc = FromServerMock()

with pytest.raises(issues.Overloaded):
Expand Down
4 changes: 2 additions & 2 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .._grpc.grpcwrapper.common_utils import (
IGrpcWrapperAsyncIO,
SupportedDriverType,
GrpcWrapperAsyncIO,
GrpcWrapperStreamStreamAsyncIO,
)
from .._grpc.grpcwrapper.ydb_topic import (
StreamReadMessage,
Expand Down Expand Up @@ -308,7 +308,7 @@ async def create(
driver: SupportedDriverType,
settings: topic_reader.PublicReaderSettings,
) -> "ReaderStream":
stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto)
stream = GrpcWrapperStreamStreamAsyncIO(StreamReadMessage.FromServer.from_proto)

await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead)

Expand Down
4 changes: 2 additions & 2 deletions ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from .._grpc.grpcwrapper.common_utils import (
IGrpcWrapperAsyncIO,
SupportedDriverType,
GrpcWrapperAsyncIO,
GrpcWrapperStreamStreamAsyncIO,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -613,7 +613,7 @@ async def create(
init_request: StreamWriteMessage.InitRequest,
update_token_interval: Optional[Union[int, float]] = None,
) -> "WriterAsyncIOStream":
stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
stream = GrpcWrapperStreamStreamAsyncIO(StreamWriteMessage.FromServer.from_proto)

await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)

Expand Down
Loading