Skip to content

Commit 6a740e4

Browse files
committed
Revert "Refactor wrappers to split UnaryStream and StreamStream wrappers"
This reverts commit 0ffe545.
1 parent c90ccf4 commit 6a740e4

File tree

4 files changed

+31
-71
lines changed

4 files changed

+31
-71
lines changed

ydb/_grpc/grpcwrapper/common_utils.py

Lines changed: 23 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def close(self):
145145
SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver]
146146

147147

148-
class AbstractGrpcWrapperAsyncIO(IGrpcWrapperAsyncIO, abc.ABC):
148+
class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
149149
from_client_grpc: asyncio.Queue
150150
from_server_grpc: AsyncIterator
151151
convert_server_grpc_to_wrapper: Callable[[Any], Any]
@@ -163,6 +163,13 @@ def __init__(self, convert_server_grpc_to_wrapper):
163163
def __del__(self):
164164
self._clean_executor(wait=False)
165165

166+
async def start(self, driver: SupportedDriverType, stub, method):
167+
if asyncio.iscoroutinefunction(driver.__call__):
168+
await self._start_asyncio_driver(driver, stub, method)
169+
else:
170+
await self._start_sync_driver(driver, stub, method)
171+
self._connection_state = "started"
172+
166173
def close(self):
167174
self.from_client_grpc.put_nowait(_stop_grpc_connection_marker)
168175
if self._stream_call:
@@ -174,35 +181,6 @@ def _clean_executor(self, wait: bool):
174181
if self._wait_executor:
175182
self._wait_executor.shutdown(wait)
176183

177-
async def receive(self) -> Any:
178-
# todo handle grpc exceptions and convert it to internal exceptions
179-
try:
180-
grpc_message = await self.from_server_grpc.__anext__()
181-
except (grpc.RpcError, grpc.aio.AioRpcError) as e:
182-
raise connection._rpc_error_handler(self._connection_state, e)
183-
184-
issues._process_response(grpc_message)
185-
186-
if self._connection_state != "has_received_messages":
187-
self._connection_state = "has_received_messages"
188-
189-
# print("rekby, grpc, received", grpc_message)
190-
return self.convert_server_grpc_to_wrapper(grpc_message)
191-
192-
def write(self, wrap_message: IToProto):
193-
grpc_message = wrap_message.to_proto()
194-
# print("rekby, grpc, send", grpc_message)
195-
self.from_client_grpc.put_nowait(grpc_message)
196-
197-
198-
class GrpcWrapperStreamStreamAsyncIO(AbstractGrpcWrapperAsyncIO):
199-
async def start(self, driver: SupportedDriverType, stub, method):
200-
if asyncio.iscoroutinefunction(driver.__call__):
201-
await self._start_asyncio_driver(driver, stub, method)
202-
else:
203-
await self._start_sync_driver(driver, stub, method)
204-
self._connection_state = "started"
205-
206184
async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
207185
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
208186
stream_call = await driver(
@@ -221,30 +199,25 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
221199
self._stream_call = stream_call
222200
self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor)
223201

202+
async def receive(self) -> Any:
203+
# todo handle grpc exceptions and convert it to internal exceptions
204+
try:
205+
grpc_message = await self.from_server_grpc.__anext__()
206+
except (grpc.RpcError, grpc.aio.AioRpcError) as e:
207+
raise connection._rpc_error_handler(self._connection_state, e)
224208

225-
class GrpcWrapperUnaryStreamAsyncIO(AbstractGrpcWrapperAsyncIO):
226-
async def start(self, driver: SupportedDriverType, request, stub, method):
227-
if asyncio.iscoroutinefunction(driver.__call__):
228-
await self._start_asyncio_driver(driver, request, stub, method)
229-
else:
230-
await self._start_sync_driver(driver, request, stub, method)
231-
self._connection_state = "started"
209+
issues._process_response(grpc_message)
232210

233-
async def _start_asyncio_driver(self, driver: ydb.aio.Driver, request, stub, method):
234-
stream_call = await driver(
235-
request,
236-
stub,
237-
method,
238-
)
239-
self._stream_call = stream_call
240-
self.from_server_grpc = stream_call.__aiter__()
211+
if self._connection_state != "has_received_messages":
212+
self._connection_state = "has_received_messages"
241213

242-
async def _start_sync_driver(self, driver: ydb.Driver, request, stub, method):
243-
self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
214+
# print("rekby, grpc, received", grpc_message)
215+
return self.convert_server_grpc_to_wrapper(grpc_message)
244216

245-
stream_call = await to_thread(driver, request, stub, method, executor=self._wait_executor)
246-
self._stream_call = stream_call
247-
self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor)
217+
def write(self, wrap_message: IToProto):
218+
grpc_message = wrap_message.to_proto()
219+
# print("rekby, grpc, send", grpc_message)
220+
self.from_client_grpc.put_nowait(grpc_message)
248221

249222

250223
@dataclass(init=False)
@@ -283,19 +256,6 @@ def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage):
283256
return res
284257

285258

286-
ResultType = typing.TypeVar("ResultType", bound=IFromProtoWithProtoType)
287-
288-
289-
def create_result_wrapper(
290-
result_type: typing.Type[ResultType],
291-
) -> typing.Callable[[typing.Any, typing.Any, typing.Any], ResultType]:
292-
def wrapper(rpc_state, response_pb, driver=None):
293-
# issues._process_response(response_pb.operation)
294-
return result_type.from_proto(response_pb)
295-
296-
return wrapper
297-
298-
299259
def callback_from_asyncio(callback: Union[Callable, Coroutine]) -> [asyncio.Future, asyncio.Task]:
300260
loop = asyncio.get_running_loop()
301261

ydb/_topic_common/common_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .common import CallFromSyncToAsync
1010
from .._grpc.grpcwrapper.common_utils import (
11-
GrpcWrapperStreamStreamAsyncIO,
11+
GrpcWrapperAsyncIO,
1212
ServerStatus,
1313
callback_from_asyncio,
1414
)
@@ -77,7 +77,7 @@ async def async_failed():
7777

7878

7979
@pytest.mark.asyncio
80-
class TestGrpcWrapperStreamStreamAsyncIO:
80+
class TestGrpcWrapperAsyncIO:
8181
async def test_convert_grpc_errors_to_ydb(self):
8282
class TestError(grpc.RpcError, grpc.Call):
8383
def __init__(self):
@@ -93,7 +93,7 @@ class FromServerMock:
9393
async def __anext__(self):
9494
raise TestError()
9595

96-
wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None)
96+
wrapper = GrpcWrapperAsyncIO(lambda: None)
9797
wrapper.from_server_grpc = FromServerMock()
9898

9999
with pytest.raises(issues.Unauthenticated):
@@ -107,7 +107,7 @@ async def __anext__(self):
107107
issues=[],
108108
)
109109

110-
wrapper = GrpcWrapperStreamStreamAsyncIO(lambda: None)
110+
wrapper = GrpcWrapperAsyncIO(lambda: None)
111111
wrapper.from_server_grpc = FromServerMock()
112112

113113
with pytest.raises(issues.Overloaded):

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .._grpc.grpcwrapper.common_utils import (
1919
IGrpcWrapperAsyncIO,
2020
SupportedDriverType,
21-
GrpcWrapperStreamStreamAsyncIO,
21+
GrpcWrapperAsyncIO,
2222
)
2323
from .._grpc.grpcwrapper.ydb_topic import (
2424
StreamReadMessage,
@@ -308,7 +308,7 @@ async def create(
308308
driver: SupportedDriverType,
309309
settings: topic_reader.PublicReaderSettings,
310310
) -> "ReaderStream":
311-
stream = GrpcWrapperStreamStreamAsyncIO(StreamReadMessage.FromServer.from_proto)
311+
stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto)
312312

313313
await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamRead)
314314

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from .._grpc.grpcwrapper.common_utils import (
4040
IGrpcWrapperAsyncIO,
4141
SupportedDriverType,
42-
GrpcWrapperStreamStreamAsyncIO,
42+
GrpcWrapperAsyncIO,
4343
)
4444

4545
logger = logging.getLogger(__name__)
@@ -613,7 +613,7 @@ async def create(
613613
init_request: StreamWriteMessage.InitRequest,
614614
update_token_interval: Optional[Union[int, float]] = None,
615615
) -> "WriterAsyncIOStream":
616-
stream = GrpcWrapperStreamStreamAsyncIO(StreamWriteMessage.FromServer.from_proto)
616+
stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
617617

618618
await stream.start(driver, _apis.TopicService.Stub, _apis.TopicService.StreamWrite)
619619

0 commit comments

Comments
 (0)