From 7e847589fdeb7c6b8139bca0288ff7aa32e26ffb Mon Sep 17 00:00:00 2001 From: ManelCoutinhoSensei Date: Mon, 14 Apr 2025 16:31:43 +0100 Subject: [PATCH 1/9] SentinelManagedConnection searches for new master upon connection failure (#3560) --- redis/asyncio/connection.py | 13 ++++--- redis/asyncio/sentinel.py | 17 ++++++---- redis/connection.py | 13 ++++--- redis/sentinel.py | 18 ++++++---- .../test_sentinel_managed_connection.py | 1 + tests/test_sentinel_managed_connection.py | 34 +++++++++++++++++++ 6 files changed, 74 insertions(+), 22 deletions(-) create mode 100644 tests/test_sentinel_managed_connection.py diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 95390bd66c..db8025b6f2 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -295,13 +295,18 @@ async def connect(self): """Connects to the Redis server if not already connected""" await self.connect_check_health(check_health=True) - async def connect_check_health(self, check_health: bool = True): + async def connect_check_health( + self, check_health: bool = True, retry_socket_connect: bool = True + ): if self.is_connected: return try: - await self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect() - ) + if retry_socket_connect: + await self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect() + ) + else: + await self._connect() except asyncio.CancelledError: raise # in 3.7 and earlier, this is an Exception, not BaseException except (socket.timeout, asyncio.TimeoutError): diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 0bf7086555..d0455ab6eb 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -11,8 +11,12 @@ SSLConnection, ) from redis.commands import AsyncSentinelCommands -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -37,11 +41,10 @@ def __repr__(self): async def connect_to(self, address): self.host, self.port = address - await super().connect() - if self.connection_pool.check_connection: - await self.send_command("PING") - if str_if_bytes(await self.read_response()) != "PONG": - raise ConnectionError("PING failed") + await self.connect_check_health( + check_health=self.connection_pool.check_connection, + retry_socket_connect=False, + ) async def _connect_retry(self): if self._reader: diff --git a/redis/connection.py b/redis/connection.py index d457b1015c..a456514a88 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -378,13 +378,18 @@ def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) - def connect_check_health(self, check_health: bool = True): + def connect_check_health( + self, check_health: bool = True, retry_socket_connect: bool = True + ): if self._sock: return try: - sock = self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect(error) - ) + if retry_socket_connect: + sock = self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect(error) + ) + else: + sock = self._connect() except socket.timeout: raise TimeoutError("Timeout connecting to server") except OSError as e: diff --git a/redis/sentinel.py b/redis/sentinel.py index 198639c932..f12bd8dd5d 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -5,8 +5,12 @@ from redis.client import Redis from redis.commands import SentinelCommands from redis.connection import Connection, ConnectionPool, SSLConnection -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError -from redis.utils import str_if_bytes +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + ResponseError, + TimeoutError, +) class MasterNotFoundError(ConnectionError): @@ -35,11 +39,11 @@ def __repr__(self): def connect_to(self, address): self.host, self.port = address - super().connect() - if self.connection_pool.check_connection: - self.send_command("PING") - if str_if_bytes(self.read_response()) != "PONG": - raise ConnectionError("PING failed") + + self.connect_check_health( + check_health=self.connection_pool.check_connection, + retry_socket_connect=False, + ) def _connect_retry(self): if self._sock: diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index 01f717ee38..5a511b2793 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -33,4 +33,5 @@ async def mock_connect(): conn._connect.side_effect = mock_connect await conn.connect() assert conn._connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 await conn.disconnect() diff --git a/tests/test_sentinel_managed_connection.py b/tests/test_sentinel_managed_connection.py new file mode 100644 index 0000000000..6fe5f7cd5b --- /dev/null +++ b/tests/test_sentinel_managed_connection.py @@ -0,0 +1,34 @@ +import socket + +from redis.retry import Retry +from redis.sentinel import SentinelManagedConnection +from redis.backoff import NoBackoff +from unittest import mock + + +def test_connect_retry_on_timeout_error(master_host): + """Test that the _connect function is retried in case of a timeout""" + connection_pool = mock.Mock() + connection_pool.get_master_address = mock.Mock( + return_value=(master_host[0], master_host[1]) + ) + conn = SentinelManagedConnection( + retry_on_timeout=True, + retry=Retry(NoBackoff(), 3), + connection_pool=connection_pool, + ) + origin_connect = conn._connect + conn._connect = mock.Mock() + + def mock_connect(): + # connect only on the last retry + if conn._connect.call_count <= 2: + raise socket.timeout + else: + return origin_connect() + + conn._connect.side_effect = mock_connect + conn.connect() + assert conn._connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 + conn.disconnect() From b7dcd8067d94c8610dcd92b3490319f4dc1eefcf Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Fri, 4 Jul 2025 18:57:17 +0300 Subject: [PATCH 2/9] Updating the latest Redis image for pipeline testing (#3695) --- .github/workflows/integration.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index c91fa91d01..b720e1d99c 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: ['8.2-M01-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.4.4', '7.2.9'] + redis-version: ['8.2-RC1-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.4.4', '7.2.9'] python-version: ['3.9', '3.13'] parser-backend: ['plain'] event-loop: ['asyncio'] From fbce374b1ea949435e27e24c0a054fe444dd680e Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Mon, 7 Jul 2025 08:33:51 +0300 Subject: [PATCH 3/9] Add support for new BITOP operations: DIFF, DIFF1, ANDOR, ONE (#3690) * Add support for new BITOP operations: DIFF, DIFF1, ANDOR, ONE * fix linting issues * change version checking from 8.2.0 to 8.1.224 --- tests/test_asyncio/test_commands.py | 97 +++++++++++++++++++++++++++++ tests/test_commands.py | 97 +++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index bfb6855a0f..bf65210f31 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -879,6 +879,103 @@ async def test_bitop_string_operands(self, r: redis.Redis): assert int(binascii.hexlify(await r.get("res2")), 16) == 0x0102FFFF assert int(binascii.hexlify(await r.get("res3")), 16) == 0x000000FF + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_diff(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("DIFF", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\x30" + + await r.bitop("DIFF", "result2", "a", "nonexistent") + assert await r.get("result2") == b"\xf0" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_diff1(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("DIFF1", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\x00" + + await r.set("d", b"\x0f") + await r.set("e", b"\x03") + await r.bitop("DIFF1", "result2", "d", "e") + assert await r.get("result2") == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_andor(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("ANDOR", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\xc0" + + await r.set("x", b"\xf0") + await r.set("y", b"\x0f") + await r.bitop("ANDOR", "result2", "x", "y") + assert await r.get("result2") == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_one(self, r: redis.Redis): + await r.set("a", b"\xf0") + await r.set("b", b"\xc0") + await r.set("c", b"\x80") + + result = await r.bitop("ONE", "result", "a", "b", "c") + assert result == 1 + assert await r.get("result") == b"\x30" + + await r.set("x", b"\xf0") + await r.set("y", b"\x0f") + await r.bitop("ONE", "result2", "x", "y") + assert await r.get("result2") == b"\xff" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_new_operations_with_empty_keys(self, r: redis.Redis): + await r.set("a", b"\xff") + + await r.bitop("DIFF", "empty_result", "nonexistent", "a") + assert await r.get("empty_result") == b"\x00" + + await r.bitop("DIFF1", "empty_result2", "a", "nonexistent") + assert await r.get("empty_result2") == b"\x00" + + await r.bitop("ANDOR", "empty_result3", "a", "nonexistent") + assert await r.get("empty_result3") == b"\x00" + + await r.bitop("ONE", "empty_result4", "nonexistent") + assert await r.get("empty_result4") is None + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + async def test_bitop_new_operations_return_values(self, r: redis.Redis): + await r.set("a", b"\xff\x00\xff") + await r.set("b", b"\x00\xff") + + result1 = await r.bitop("DIFF", "result1", "a", "b") + assert result1 == 3 + + result2 = await r.bitop("DIFF1", "result2", "a", "b") + assert result2 == 3 + + result3 = await r.bitop("ANDOR", "result3", "a", "b") + assert result3 == 3 + + result4 = await r.bitop("ONE", "result4", "a", "b") + assert result4 == 3 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.7") async def test_bitpos(self, r: redis.Redis): diff --git a/tests/test_commands.py b/tests/test_commands.py index 04574cbb81..9ac8ee4933 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1313,6 +1313,103 @@ def test_bitop_string_operands(self, r): assert int(binascii.hexlify(r["res2"]), 16) == 0x0102FFFF assert int(binascii.hexlify(r["res3"]), 16) == 0x000000FF + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_diff(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("DIFF", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\x30" + + r.bitop("DIFF", "result2", "a", "nonexistent") + assert r["result2"] == b"\xf0" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_diff1(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("DIFF1", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\x00" + + r["d"] = b"\x0f" + r["e"] = b"\x03" + r.bitop("DIFF1", "result2", "d", "e") + assert r["result2"] == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_andor(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("ANDOR", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\xc0" + + r["x"] = b"\xf0" + r["y"] = b"\x0f" + r.bitop("ANDOR", "result2", "x", "y") + assert r["result2"] == b"\x00" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_one(self, r): + r["a"] = b"\xf0" + r["b"] = b"\xc0" + r["c"] = b"\x80" + + result = r.bitop("ONE", "result", "a", "b", "c") + assert result == 1 + assert r["result"] == b"\x30" + + r["x"] = b"\xf0" + r["y"] = b"\x0f" + r.bitop("ONE", "result2", "x", "y") + assert r["result2"] == b"\xff" + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_new_operations_with_empty_keys(self, r): + r["a"] = b"\xff" + + r.bitop("DIFF", "empty_result", "nonexistent", "a") + assert r.get("empty_result") == b"\x00" + + r.bitop("DIFF1", "empty_result2", "a", "nonexistent") + assert r.get("empty_result2") == b"\x00" + + r.bitop("ANDOR", "empty_result3", "a", "nonexistent") + assert r.get("empty_result3") == b"\x00" + + r.bitop("ONE", "empty_result4", "nonexistent") + assert r.get("empty_result4") is None + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("8.1.224") + def test_bitop_new_operations_return_values(self, r): + r["a"] = b"\xff\x00\xff" + r["b"] = b"\x00\xff" + + result1 = r.bitop("DIFF", "result1", "a", "b") + assert result1 == 3 + + result2 = r.bitop("DIFF1", "result2", "a", "b") + assert result2 == 3 + + result3 = r.bitop("ANDOR", "result3", "a", "b") + assert result3 == 3 + + result4 = r.bitop("ONE", "result4", "a", "b") + assert result4 == 3 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.7") def test_bitpos(self, r): From a00d182445961c883ce1892be8f5ebd2272dae86 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:35:47 +0300 Subject: [PATCH 4/9] Bump rojopolis/spellcheck-github-actions from 0.49.0 to 0.51.0 (#3689) Bumps [rojopolis/spellcheck-github-actions](https://github.com/rojopolis/spellcheck-github-actions) from 0.49.0 to 0.51.0. - [Release notes](https://github.com/rojopolis/spellcheck-github-actions/releases) - [Changelog](https://github.com/rojopolis/spellcheck-github-actions/blob/master/CHANGELOG.md) - [Commits](https://github.com/rojopolis/spellcheck-github-actions/compare/0.49.0...0.51.0) --- updated-dependencies: - dependency-name: rojopolis/spellcheck-github-actions dependency-version: 0.51.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: petyaslavova --- .github/workflows/spellcheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml index 6ab2c46701..81e73cd4ba 100644 --- a/.github/workflows/spellcheck.yml +++ b/.github/workflows/spellcheck.yml @@ -8,7 +8,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Check Spelling - uses: rojopolis/spellcheck-github-actions@0.49.0 + uses: rojopolis/spellcheck-github-actions@0.51.0 with: config_path: .github/spellcheck-settings.yml task_name: Markdown From 50773eccaf36b85c78dfeeed8440fc295169eb1f Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:09:14 +0100 Subject: [PATCH 5/9] DOC-5225 testable probabilistic dt examples (#3691) --- doctests/home_prob_dts.py | 232 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 doctests/home_prob_dts.py diff --git a/doctests/home_prob_dts.py b/doctests/home_prob_dts.py new file mode 100644 index 0000000000..39d516242f --- /dev/null +++ b/doctests/home_prob_dts.py @@ -0,0 +1,232 @@ +# EXAMPLE: home_prob_dts +""" +Probabilistic data type examples: + https://redis.io/docs/latest/develop/connect/clients/python/redis-py/prob +""" + +# HIDE_START +import redis +r = redis.Redis(decode_responses=True) +# HIDE_END +# REMOVE_START +r.delete( + "recorded_users", "other_users", + "group:1", "group:2", "both_groups", + "items_sold", + "male_heights", "female_heights", "all_heights", + "top_3_songs" +) +# REMOVE_END + +# STEP_START bloom +res1 = r.bf().madd("recorded_users", "andy", "cameron", "david", "michelle") +print(res1) # >>> [1, 1, 1, 1] + +res2 = r.bf().exists("recorded_users", "cameron") +print(res2) # >>> 1 + +res3 = r.bf().exists("recorded_users", "kaitlyn") +print(res3) # >>> 0 +# STEP_END +# REMOVE_START +assert res1 == [1, 1, 1, 1] +assert res2 == 1 +assert res3 == 0 +# REMOVE_END + +# STEP_START cuckoo +res4 = r.cf().add("other_users", "paolo") +print(res4) # >>> 1 + +res5 = r.cf().add("other_users", "kaitlyn") +print(res5) # >>> 1 + +res6 = r.cf().add("other_users", "rachel") +print(res6) # >>> 1 + +res7 = r.cf().mexists("other_users", "paolo", "rachel", "andy") +print(res7) # >>> [1, 1, 0] + +res8 = r.cf().delete("other_users", "paolo") +print(res8) # >>> 1 + +res9 = r.cf().exists("other_users", "paolo") +print(res9) # >>> 0 +# STEP_END +# REMOVE_START +assert res4 == 1 +assert res5 == 1 +assert res6 == 1 +assert res7 == [1, 1, 0] +assert res8 == 1 +assert res9 == 0 +# REMOVE_END + +# STEP_START hyperloglog +res10 = r.pfadd("group:1", "andy", "cameron", "david") +print(res10) # >>> 1 + +res11 = r.pfcount("group:1") +print(res11) # >>> 3 + +res12 = r.pfadd("group:2", "kaitlyn", "michelle", "paolo", "rachel") +print(res12) # >>> 1 + +res13 = r.pfcount("group:2") +print(res13) # >>> 4 + +res14 = r.pfmerge("both_groups", "group:1", "group:2") +print(res14) # >>> True + +res15 = r.pfcount("both_groups") +print(res15) # >>> 7 +# STEP_END +# REMOVE_START +assert res10 == 1 +assert res11 == 3 +assert res12 == 1 +assert res13 == 4 +assert res14 +assert res15 == 7 +# REMOVE_END + +# STEP_START cms +# Specify that you want to keep the counts within 0.01 +# (1%) of the true value with a 0.005 (0.5%) chance +# of going outside this limit. +res16 = r.cms().initbyprob("items_sold", 0.01, 0.005) +print(res16) # >>> True + +# The parameters for `incrby()` are two lists. The count +# for each item in the first list is incremented by the +# value at the same index in the second list. +res17 = r.cms().incrby( + "items_sold", + ["bread", "tea", "coffee", "beer"], # Items sold + [300, 200, 200, 100] +) +print(res17) # >>> [300, 200, 200, 100] + +res18 = r.cms().incrby( + "items_sold", + ["bread", "coffee"], + [100, 150] +) +print(res18) # >>> [400, 350] + +res19 = r.cms().query("items_sold", "bread", "tea", "coffee", "beer") +print(res19) # >>> [400, 200, 350, 100] +# STEP_END +# REMOVE_START +assert res16 +assert res17 == [300, 200, 200, 100] +assert res18 == [400, 350] +assert res19 == [400, 200, 350, 100] +# REMOVE_END + +# STEP_START tdigest +res20 = r.tdigest().create("male_heights") +print(res20) # >>> True + +res21 = r.tdigest().add( + "male_heights", + [175.5, 181, 160.8, 152, 177, 196, 164] +) +print(res21) # >>> OK + +res22 = r.tdigest().min("male_heights") +print(res22) # >>> 152.0 + +res23 = r.tdigest().max("male_heights") +print(res23) # >>> 196.0 + +res24 = r.tdigest().quantile("male_heights", 0.75) +print(res24) # >>> 181 + +# Note that the CDF value for 181 is not exactly +# 0.75. Both values are estimates. +res25 = r.tdigest().cdf("male_heights", 181) +print(res25) # >>> [0.7857142857142857] + +res26 = r.tdigest().create("female_heights") +print(res26) # >>> True + +res27 = r.tdigest().add( + "female_heights", + [155.5, 161, 168.5, 170, 157.5, 163, 171] +) +print(res27) # >>> OK + +res28 = r.tdigest().quantile("female_heights", 0.75) +print(res28) # >>> [170] + +res29 = r.tdigest().merge( + "all_heights", 2, "male_heights", "female_heights" +) +print(res29) # >>> OK + +res30 = r.tdigest().quantile("all_heights", 0.75) +print(res30) # >>> [175.5] +# STEP_END +# REMOVE_START +assert res20 +assert res21 == "OK" +assert res22 == 152.0 +assert res23 == 196.0 +assert res24 == [181] +assert res25 == [0.7857142857142857] +assert res26 +assert res27 == "OK" +assert res28 == [170] +assert res29 == "OK" +assert res30 == [175.5] +# REMOVE_END + +# STEP_START topk +# The `reserve()` method creates the Top-K object with +# the given key. The parameters are the number of items +# in the ranking and values for `width`, `depth`, and +# `decay`, described in the Top-K reference page. +res31 = r.topk().reserve("top_3_songs", 3, 7, 8, 0.9) +print(res31) # >>> True + +# The parameters for `incrby()` are two lists. The count +# for each item in the first list is incremented by the +# value at the same index in the second list. +res32 = r.topk().incrby( + "top_3_songs", + [ + "Starfish Trooper", + "Only one more time", + "Rock me, Handel", + "How will anyone know?", + "Average lover", + "Road to everywhere" + ], + [ + 3000, + 1850, + 1325, + 3890, + 4098, + 770 + ] +) +print(res32) +# >>> [None, None, None, 'Rock me, Handel', 'Only one more time', None] + +res33 = r.topk().list("top_3_songs") +print(res33) +# >>> ['Average lover', 'How will anyone know?', 'Starfish Trooper'] + +res34 = r.topk().query( + "top_3_songs", "Starfish Trooper", "Road to everywhere" +) +print(res34) # >>> [1, 0] +# STEP_END +# REMOVE_START +assert res31 +assert res32 == [None, None, None, 'Rock me, Handel', 'Only one more time', None] +assert res33 == ['Average lover', 'How will anyone know?', 'Starfish Trooper'] +assert res34 == [1, 0] +# REMOVE_END From 07b0e2c84ad9aa6a31b4c5baa58ff3f6f5c605ab Mon Sep 17 00:00:00 2001 From: AmirHossein <2002gholami@gmail.com> Date: Tue, 8 Jul 2025 15:18:55 +0330 Subject: [PATCH 6/9] Update README.md (#3699) Co-authored-by: petyaslavova --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1c4dfc0f11..97afa2f9bc 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ Start a redis via docker (for Redis versions < 8.0): ``` bash docker run -p 6379:6379 -it redis/redis-stack:latest - +``` To install redis-py, simply: ``` bash @@ -209,4 +209,4 @@ Special thanks to: system. - Paul Hubbard for initial packaging support. -[![Redis](./docs/_static/logo-redis.svg)](https://redis.io) \ No newline at end of file +[![Redis](./docs/_static/logo-redis.svg)](https://redis.io) From 310813aae8d1e27aa492c5b6b292276b0f331593 Mon Sep 17 00:00:00 2001 From: Mitch Harding Date: Wed, 9 Jul 2025 08:32:25 -0400 Subject: [PATCH 7/9] Annotate deprecated_args decorator to preserve wrapped function type signature (#3701) --- redis/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/redis/utils.py b/redis/utils.py index 715913e914..79c23c8bda 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,9 +1,10 @@ import datetime import logging import textwrap +from collections.abc import Callable from contextlib import contextmanager from functools import wraps -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union from redis.exceptions import DataError from redis.typing import AbsExpiryT, EncodableT, ExpiryT @@ -150,18 +151,21 @@ def warn_deprecated_arg_usage( warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) +C = TypeVar("C", bound=Callable) + + def deprecated_args( args_to_warn: list = ["*"], allowed_args: list = [], reason: str = "", version: str = "", -): +) -> Callable[[C], C]: """ Decorator to mark specified args of a function as deprecated. If '*' is in args_to_warn, all arguments will be marked as deprecated. """ - def decorator(func): + def decorator(func: C) -> C: @wraps(func) def wrapper(*args, **kwargs): # Get function argument names From ce56d1cb0d214c86c3bd3e8348510889aa6e6f0c Mon Sep 17 00:00:00 2001 From: hulk Date: Fri, 11 Jul 2025 19:41:47 +0800 Subject: [PATCH 8/9] Convert the value to int type only if it exists in CLIENT INFO (#3688) Currently, client info will try to convert the set of keys to the int type without checking if it exists or not. For example, both `argv-mem` and `tot-mem` are introduced in 6.2, and force converting an non-existent value might cause an exception in older or newer versions. To keep the compatibility with the older/newer Redis server, we could just convert it only if the key exists. --- redis/_parsers/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index 5468addf62..154dc66dfb 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -676,7 +676,8 @@ def parse_client_info(value): "omem", "tot-mem", }: - client_info[int_key] = int(client_info[int_key]) + if int_key in client_info: + client_info[int_key] = int(client_info[int_key]) return client_info From a757bad78bf18869cdceb370ca8d05d3a7c96942 Mon Sep 17 00:00:00 2001 From: ofekshenawa <104765379+ofekshenawa@users.noreply.github.com> Date: Mon, 14 Jul 2025 10:31:47 +0300 Subject: [PATCH 9/9] Support new VAMANA vector type (#3702) * Support new vector type * Skip VAMANA tests is redis versin is not 8.2 * Add async tests * Fix resp 3 errors --- redis/commands/search/field.py | 8 +- tests/test_asyncio/test_search.py | 178 +++++++ tests/test_search.py | 840 +++++++++++++++++++++++++++++- 3 files changed, 1020 insertions(+), 6 deletions(-) diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 8af7777f19..45cd403e49 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -181,7 +181,7 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): ``name`` is the name of the field. - ``algorithm`` can be "FLAT" or "HNSW". + ``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA". ``attributes`` each algorithm can have specific attributes. Some of them are mandatory and some of them are optional. See @@ -194,10 +194,10 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): if sort or noindex: raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.") - if algorithm.upper() not in ["FLAT", "HNSW"]: + if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]: raise DataError( - "Realtime vector indexing supporting 2 Indexing Methods:" - "'FLAT' and 'HNSW'." + "Realtime vector indexing supporting 3 Indexing Methods:" + "'FLAT', 'HNSW', and 'SVS-VAMANA'." ) attr_li = [] diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 932ece59b8..0004f9ba75 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1815,3 +1815,181 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): assert docs[0]["first_name"] == mixed_data["first_name"], ( "The text field is not decoded correctly" ) + + +# SVS-VAMANA Async Tests +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_basic_functionality(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + + for i, vec in enumerate(vectors): + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = await decoded_r.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) + + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_distance_metrics(decoded_r: redis.Redis): + # Test COSINE distance + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_vector_types(decoded_r: redis.Redis): + # Test FLOAT16 + await decoded_r.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + + for i, vec in enumerate(vectors): + await decoded_r.hset( + f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes() + ) + + query = Query("*=>[KNN 2 @v16 $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = await decoded_r.ft("idx16").search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 2 + assert "doc16_0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc16_0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_compression(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_if_server_version_lt("8.1.224") +async def test_async_svs_vamana_build_parameters(decoded_r: redis.Redis): + await decoded_r.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "CONSTRUCTION_WINDOW_SIZE": 300, + "GRAPH_MAX_DEGREE": 64, + "SEARCH_WINDOW_SIZE": 20, + "EPSILON": 0.05, + }, + ), + ) + ) + + vectors = [] + for i in range(15): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = await decoded_r.ft().search(query, query_params=query_params) + if is_resp2_connection(decoded_r): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] diff --git a/tests/test_search.py b/tests/test_search.py index 4af55e8a17..3460b56ca1 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2863,6 +2863,100 @@ def test_vector_search_with_default_dialect(client): assert res["total_results"] == 2 +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_l2_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + # L2 distance test vectors + vectors = [[1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_cosine_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"}, + ), + ) + ) + + vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_ip_distance_metric(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "IP"}, + ), + ) + ) + + vectors = [[1.0, 2.0, 3.0], [2.0, 1.0, 1.0], [3.0, 3.0, 3.0], [0.1, 0.1, 0.1]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").sort_by("score").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc2" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc2" == res["results"][0]["id"] + + @pytest.mark.redismod @skip_if_server_version_lt("7.9.0") def test_vector_search_with_int8_type(client): @@ -2878,7 +2972,7 @@ def test_vector_search_with_int8_type(client): client.hset("b", "v", np.array(b, dtype=np.int8).tobytes()) client.hset("c", "v", np.array(c, dtype=np.int8).tobytes()) - query = Query("*=>[KNN 2 @v $vec as score]") + query = Query("*=>[KNN 2 @v $vec as score]").no_content() query_params = {"vec": np.array(a, dtype=np.int8).tobytes()} assert 2 in query.get_args() @@ -2909,7 +3003,7 @@ def test_vector_search_with_uint8_type(client): client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes()) client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes()) - query = Query("*=>[KNN 2 @v $vec as score]") + query = Query("*=>[KNN 2 @v $vec as score]").no_content() query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()} assert 2 in query.get_args() @@ -2966,3 +3060,745 @@ def _assert_search_result(client, result, expected_doc_ids): assert set([doc.id for doc in result.docs]) == set(expected_doc_ids) else: assert set([doc["id"] for doc in result["results"]]) == set(expected_doc_ids) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_basic_functionality(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [10.0, 11.0, 12.0, 13.0], + ] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = "*=>[KNN 3 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search( + q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + ) + + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id # Should be closest to itself + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_float16_type(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_float32_type(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + vectors = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [3.0, 4.0, 5.0, 6.0]] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 2 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_search_with_default_dialect(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}, + ), + ) + ) + + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") + + query = "*=>[KNN 2 @v $vec]" + q = Query(query).return_field("__v_score").sort_by("__v_score", True) + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_field_basic(): + field = VectorField( + "v", "SVS-VAMANA", {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"} + ) + + # Check that the field was created successfully + assert field.name == "v" + assert field.args[0] == "VECTOR" + assert field.args[1] == "SVS-VAMANA" + assert field.args[2] == 6 + assert "TYPE" in field.args + assert "FLOAT32" in field.args + assert "DIM" in field.args + assert 128 in field.args + assert "DISTANCE_METRIC" in field.args + assert "COSINE" in field.args + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_lvq8_compression(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_compression_with_both_vector_types(client): + # Test FLOAT16 with LVQ8 + client.ft("idx16").create_index( + ( + VectorField( + "v16", + "SVS-VAMANA", + { + "TYPE": "FLOAT16", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + # Test FLOAT32 with LVQ8 + client.ft("idx32").create_index( + ( + VectorField( + "v32", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + # Add data to both indices + for i in range(15): + vec = [float(i + j) for j in range(8)] + client.hset(f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes()) + client.hset(f"doc32_{i}", "v32", np.array(vec, dtype=np.float32).tobytes()) + + # Test both indices + query = Query("*=>[KNN 3 @v16 $vec as score]").no_content() + res16 = client.ft("idx16").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float16 + ).tobytes() + }, + ) + + query = Query("*=>[KNN 3 @v32 $vec as score]").no_content() + res32 = client.ft("idx32").search( + query, + query_params={ + "vec": np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], dtype=np.float32 + ).tobytes() + }, + ) + + if is_resp2_connection(client): + assert res16.total == 3 + assert res32.total == 3 + else: + assert res16["total_results"] == 3 + assert res32["total_results"] == 3 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_construction_window_size(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 300, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_graph_max_degree(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "GRAPH_MAX_DEGREE": 64, + }, + ), + ) + ) + + vectors = [] + for i in range(25): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 6 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 6 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 6 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_search_window_size(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "SEARCH_WINDOW_SIZE": 20, + }, + ), + ) + ) + + vectors = [] + for i in range(30): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 8 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 8 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 8 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_epsilon_parameter(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + {"TYPE": "FLOAT32", "DIM": 6, "DISTANCE_METRIC": "L2", "EPSILON": 0.05}, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_all_build_parameters_combined(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "IP", + "CONSTRUCTION_WINDOW_SIZE": 250, + "GRAPH_MAX_DEGREE": 48, + "SEARCH_WINDOW_SIZE": 15, + "EPSILON": 0.02, + }, + ), + ) + ) + + vectors = [] + for i in range(35): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 7 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 7 + doc_ids = [doc.id for doc in res.docs] + assert len(doc_ids) == 7 + else: + assert res["total_results"] == 7 + doc_ids = [doc["id"] for doc in res["results"]] + assert len(doc_ids) == 7 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_comprehensive_configuration(client): + client.flushdb() + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT16", + "DIM": 32, + "DISTANCE_METRIC": "COSINE", + "COMPRESSION": "LVQ8", + "CONSTRUCTION_WINDOW_SIZE": 400, + "GRAPH_MAX_DEGREE": 96, + "SEARCH_WINDOW_SIZE": 25, + "EPSILON": 0.03, + "TRAINING_THRESHOLD": 2048, + }, + ), + ) + ) + + vectors = [] + for i in range(60): + vec = [float(i + j) for j in range(32)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float16).tobytes()) + + query = Query("*=>[KNN 10 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 10 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 10 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_hybrid_text_vector_search(client): + client.flushdb() + client.ft().create_index( + ( + TextField("title"), + TextField("content"), + VectorField( + "embedding", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "COSINE", + "SEARCH_WINDOW_SIZE": 20, + }, + ), + ) + ) + + docs = [ + { + "title": "AI Research", + "content": "machine learning algorithms", + "embedding": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + }, + { + "title": "Data Science", + "content": "statistical analysis methods", + "embedding": [2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + }, + { + "title": "Deep Learning", + "content": "neural network architectures", + "embedding": [3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + }, + { + "title": "Computer Vision", + "content": "image processing techniques", + "embedding": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0], + }, + ] + + for i, doc in enumerate(docs): + client.hset( + f"doc{i}", + mapping={ + "title": doc["title"], + "content": doc["content"], + "embedding": np.array(doc["embedding"], dtype=np.float32).tobytes(), + }, + ) + + # Hybrid query - text filter + vector similarity + query = "(@title:AI|@content:machine)=>[KNN 2 @embedding $vec]" + q = ( + Query(query) + .return_field("__embedding_score") + .sort_by("__embedding_score", True) + ) + res = client.ft().search( + q, + query_params={ + "vec": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32).tobytes() + }, + ) + + if is_resp2_connection(client): + assert res.total >= 1 + doc_ids = [doc.id for doc in res.docs] + assert "doc0" in doc_ids + else: + assert res["total_results"] >= 1 + doc_ids = [doc["id"] for doc in res["results"]] + assert "doc0" in doc_ids + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_large_dimension_vectors(client): + client.flushdb() + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 512, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 300, + "GRAPH_MAX_DEGREE": 64, + }, + ), + ) + ) + + vectors = [] + for i in range(10): + vec = [float(i + j) for j in range(512)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 5 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 5 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 5 + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_training_threshold_behavior(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 8, + "DISTANCE_METRIC": "L2", + "COMPRESSION": "LVQ8", + "TRAINING_THRESHOLD": 1024, + }, + ), + ) + ) + + vectors = [] + for i in range(20): + vec = [float(i + j) for j in range(8)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + if i >= 5: + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + + if is_resp2_connection(client): + assert res.total >= 1 + else: + assert res["total_results"] >= 1 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_different_k_values(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 6, + "DISTANCE_METRIC": "L2", + "SEARCH_WINDOW_SIZE": 15, + }, + ), + ) + ) + + vectors = [] + for i in range(25): + vec = [float(i + j) for j in range(6)] + vectors.append(vec) + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + for k in [1, 3, 5, 10, 15]: + query = Query(f"*=>[KNN {k} @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + res = client.ft().search(query, query_params=query_params) + + if is_resp2_connection(client): + assert res.total == k + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == k + assert "doc0" == res["results"][0]["id"] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_field_error(client): + # sortable tag + with pytest.raises(Exception): + client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, sortable=True),)) + + # no_index tag + with pytest.raises(Exception): + client.ft().create_index((VectorField("v", "SVS-VAMANA", {}, no_index=True),)) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_lt("8.1.224") +def test_svs_vamana_vector_search_with_parameters(client): + client.ft().create_index( + ( + VectorField( + "v", + "SVS-VAMANA", + { + "TYPE": "FLOAT32", + "DIM": 4, + "DISTANCE_METRIC": "L2", + "CONSTRUCTION_WINDOW_SIZE": 200, + "GRAPH_MAX_DEGREE": 64, + "SEARCH_WINDOW_SIZE": 40, + "EPSILON": 0.01, + }, + ), + ) + ) + + # Create test vectors + vectors = [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + ] + + for i, vec in enumerate(vectors): + client.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes()) + + query = Query("*=>[KNN 3 @v $vec as score]").no_content() + query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()} + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 3 + assert "doc0" == res.docs[0].id + else: + assert res["total_results"] == 3 + assert "doc0" == res["results"][0]["id"]