diff --git a/tests/aio/query/test_query_session.py b/tests/aio/query/test_query_session.py index 0bd06fba..67db045a 100644 --- a/tests/aio/query/test_query_session.py +++ b/tests/aio/query/test_query_session.py @@ -103,10 +103,13 @@ async def test_basic_execute(self, session: QuerySession): async def test_two_results(self, session: QuerySession): await session.create() res = [] + counter = 0 async with await session.execute("select 1; select 2") as results: async for result_set in results: + counter += 1 if len(result_set.rows) > 0: res.append(list(result_set.rows[0].values())) assert res == [[1], [2]] + assert counter == 2 diff --git a/tests/aio/query/test_query_transaction.py b/tests/aio/query/test_query_transaction.py index 47222d0b..aa59abb3 100644 --- a/tests/aio/query/test_query_transaction.py +++ b/tests/aio/query/test_query_transaction.py @@ -92,3 +92,18 @@ async def test_execute_as_context_manager(self, tx: QueryTxContext): res = [result_set async for result_set in results] assert len(res) == 1 + + @pytest.mark.asyncio + async def test_execute_two_results(self, tx: QueryTxContext): + await tx.begin() + counter = 0 + res = [] + + async with await tx.execute("select 1; select 2") as results: + async for result_set in results: + counter += 1 + if len(result_set.rows) > 0: + res.append(list(result_set.rows[0].values())) + + assert res == [[1], [2]] + assert counter == 2 diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index 6c1bc3e8..f151661f 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -98,13 +98,16 @@ def test_basic_execute(self, session: QuerySession): def test_two_results(self, session: QuerySession): session.create() res = [] + counter = 0 with session.execute("select 1; select 2") as results: for result_set in results: + counter += 1 if len(result_set.rows) > 0: res.append(list(result_set.rows[0].values())) assert res == [[1], [2]] + assert counter == 2 def test_thread_leaks(self, session: QuerySession): session.create() diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 9e78988a..dfc88897 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -79,3 +79,16 @@ def test_execute_as_context_manager(self, tx: QueryTxContext): res = [result_set for result_set in results] assert len(res) == 1 + + def test_execute_two_results(self, tx: QueryTxContext): + tx.begin() + counter = 0 + res = [] + + with tx.execute("select 1; select 2") as results: + for result_set in results: + counter += 1 + res.append(list(result_set.rows[0].values())) + + assert res == [[1], [2]] + assert counter == 2 diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 117c7407..8496dbd9 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -161,7 +161,10 @@ def __iter__(self): return self def _next(self): - return self.wrapper(next(self.it)) + res = self.wrapper(next(self.it)) + if res is not None: + return res + return self._next() def next(self): return self._next() diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 5bd0f1a0..296cd256 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -17,7 +17,10 @@ def __aiter__(self): return self async def _next(self): - return self.wrapper(await self.it.__anext__()) + res = self.wrapper(await self.it.__anext__()) + if res is not None: + return res + return await self._next() async def next(self): return await self._next() diff --git a/ydb/query/base.py b/ydb/query/base.py index 9372cbcf..57a769bb 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -192,4 +192,7 @@ def wrap_execute_query_response( elif tx and response_pb.tx_meta and not tx.tx_id: tx._move_to_beginned(response_pb.tx_meta.id) - return convert.ResultSet.from_message(response_pb.result_set, settings) + if response_pb.HasField("result_set"): + return convert.ResultSet.from_message(response_pb.result_set, settings) + + return None