diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index cae8f488..a285ad2c 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -317,8 +317,7 @@ async def subscribe( if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False - if isinstance(e, GeneratorExit): - raise e + raise e finally: log.debug(f"In subscribe finally for query_id {query_id}") diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 7c000d01..22dd1004 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -292,16 +292,24 @@ async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -317,6 +325,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 83ae3589..32daf038 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -283,16 +283,24 @@ async def test_aiohttp_websocket_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -308,6 +316,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index b4c6a17b..45e7aba4 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -290,16 +290,24 @@ async def test_graphqlws_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -315,6 +323,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 8d2fd152..487b9ba5 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -210,16 +210,24 @@ async def test_websocket_subscription_task_cancel(client_and_server, subscriptio count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -235,6 +243,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio