Skip to content

Commit 7b48b1b

Browse files
authored
AsyncIO Race Condition Fix (#2639)
1 parent 54a1dce commit 7b48b1b

File tree

6 files changed

+64
-8
lines changed

6 files changed

+64
-8
lines changed

.github/workflows/integration.yaml

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ jobs:
3232
invoke linters
3333
3434
run-tests:
35-
runs-on: ubuntu-latest
35+
runs-on: ubuntu-20.04
3636
timeout-minutes: 30
3737
strategy:
3838
max-parallel: 15
39+
fail-fast: false
3940
matrix:
4041
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
4142
test-type: ['standalone', 'cluster']
@@ -79,8 +80,9 @@ jobs:
7980
8081
install_package_from_commit:
8182
name: Install package from commit hash
82-
runs-on: ubuntu-latest
83+
runs-on: ubuntu-20.04
8384
strategy:
85+
fail-fast: false
8486
matrix:
8587
python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
8688
steps:

redis/asyncio/client.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1349,10 +1349,16 @@ async def execute(self, raise_on_error: bool = True):
13491349
conn = cast(Connection, conn)
13501350

13511351
try:
1352-
return await conn.retry.call_with_retry(
1353-
lambda: execute(conn, stack, raise_on_error),
1354-
lambda error: self._disconnect_raise_reset(conn, error),
1352+
return await asyncio.shield(
1353+
conn.retry.call_with_retry(
1354+
lambda: execute(conn, stack, raise_on_error),
1355+
lambda error: self._disconnect_raise_reset(conn, error),
1356+
)
13551357
)
1358+
except asyncio.CancelledError:
1359+
# not supposed to be possible, yet here we are
1360+
await conn.disconnect(nowait=True)
1361+
raise
13561362
finally:
13571363
await self.reset()
13581364

redis/asyncio/cluster.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -879,10 +879,18 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
879879
await connection.send_packed_command(connection.pack_command(*args), False)
880880

881881
# Read response
882+
return await asyncio.shield(
883+
self._parse_and_release(connection, args[0], **kwargs)
884+
)
885+
886+
async def _parse_and_release(self, connection, *args, **kwargs):
882887
try:
883-
return await self.parse_response(connection, args[0], **kwargs)
888+
return await self.parse_response(connection, *args, **kwargs)
889+
except asyncio.CancelledError:
890+
# should not be possible
891+
await connection.disconnect(nowait=True)
892+
raise
884893
finally:
885-
# Release connection
886894
self._free.append(connection)
887895

888896
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
long_description_content_type="text/markdown",
99
keywords=["Redis", "key-value store", "database"],
1010
license="MIT",
11-
version="4.3.5",
11+
version="4.3.6",
1212
packages=find_packages(
1313
include=[
1414
"redis",

tests/test_asyncio/test_cluster.py

+17
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,23 @@ async def test_execute_command_node_flag_random(self, r: RedisCluster) -> None:
333333
called_count += 1
334334
assert called_count == 1
335335

336+
async def test_asynckills(self, r) -> None:
337+
338+
await r.set("foo", "foo")
339+
await r.set("bar", "bar")
340+
341+
t = asyncio.create_task(r.get("foo"))
342+
await asyncio.sleep(1)
343+
t.cancel()
344+
try:
345+
await t
346+
except asyncio.CancelledError:
347+
pytest.fail("connection is left open with unread response")
348+
349+
assert await r.get("bar") == b"bar"
350+
assert await r.ping()
351+
assert await r.get("foo") == b"foo"
352+
336353
async def test_execute_command_default_node(self, r: RedisCluster) -> None:
337354
"""
338355
Test command execution without node flag is being executed on the

tests/test_asyncio/test_connection.py

+23
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ async def test_invalid_response(create_redis):
2828
assert str(cm.value) == f"Protocol Error: {raw!r}"
2929

3030

31+
@pytest.mark.onlynoncluster
32+
async def test_asynckills():
33+
from redis.asyncio.client import Redis
34+
35+
for b in [True, False]:
36+
r = Redis(single_connection_client=b)
37+
38+
await r.set("foo", "foo")
39+
await r.set("bar", "bar")
40+
41+
t = asyncio.create_task(r.get("foo"))
42+
await asyncio.sleep(1)
43+
t.cancel()
44+
try:
45+
await t
46+
except asyncio.CancelledError:
47+
pytest.fail("connection left open with unread response")
48+
49+
assert await r.get("bar") == b"bar"
50+
assert await r.ping()
51+
assert await r.get("foo") == b"foo"
52+
53+
3154
@skip_if_server_version_lt("4.0.0")
3255
@pytest.mark.redismod
3356
@pytest.mark.onlynoncluster

0 commit comments

Comments
 (0)