Skip to content

Add tests for vector search INT8/UINT8 types #4091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ start: cleanup compile-module
echo "$$REDIS_UDS" | redis-server -
echo "$$REDIS_UNAVAILABLE_CONF" | redis-server -
redis-cli -a cluster --cluster create 127.0.0.1:7479 127.0.0.1:7480 127.0.0.1:7481 --cluster-yes
docker run -p 6479:6379 --name jedis-stack -d redis/redis-stack-server:edge
docker run -p 6479:6379 --name jedis-stack -e PORT=6379 -d redislabs/client-libs-test:8.0-M04-pre

cleanup:
- rm -vf /tmp/redis_cluster_node*.conf 2>/dev/null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static redis.clients.jedis.util.RedisConditions.ModuleVersion.SEARCH_MOD_VER_80M3;

import io.redis.test.annotations.SinceRedisVersion;
import io.redis.test.utils.RedisVersion;
Expand Down Expand Up @@ -39,11 +40,6 @@ public class AggregationTest extends RedisModuleCommandsTestBase {
public static void prepare() {
RedisModuleCommandsTestBase.prepare();
}
//
// @AfterClass
// public static void tearDown() {
//// RedisModuleCommandsTestBase.tearDown();
// }

public AggregationTest(RedisProtocol redisProtocol) {
super(redisProtocol);
Expand Down Expand Up @@ -205,7 +201,7 @@ public void testAggregationBuilderAddScores() {
.apply("@__score * 100", "normalized_score").dialect(3);

AggregationResult res = client.ftAggregate(index, r);
if (RedisConditions.of(client).moduleVersionIsGreatherThan("SEARCH", 79900)) {
if (RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3)) {
// Default scorer is BM25
assertEquals(0.6931, res.getRow(0).getDouble("__score"), 0.0001);
assertEquals(69.31, res.getRow(0).getDouble("normalized_score"), 0.01);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;
import static redis.clients.jedis.util.AssertUtil.assertOK;
import static redis.clients.jedis.util.RedisConditions.ModuleVersion.SEARCH_MOD_VER_80M3;

import java.util.*;
import java.util.stream.Collectors;

import io.redis.test.annotations.SinceRedisVersion;
import io.redis.test.utils.RedisVersion;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -33,6 +36,7 @@
import redis.clients.jedis.search.schemafields.GeoShapeField.CoordinateSystem;
import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;
import redis.clients.jedis.modules.RedisModuleCommandsTestBase;
import redis.clients.jedis.util.RedisConditions;
import redis.clients.jedis.util.RedisVersionUtil;

@RunWith(Parameterized.class)
Expand All @@ -44,11 +48,13 @@ public class SearchWithParamsTest extends RedisModuleCommandsTestBase {
public static void prepare() {
RedisModuleCommandsTestBase.prepare();
}
//
// @AfterClass
// public static void tearDown() {
//// RedisModuleCommandsTestBase.tearDown();
// }

@After
public void cleanUp() {
if (client.ftList().contains(index)) {
client.ftDropIndex(index);
}
}

public SearchWithParamsTest(RedisProtocol protocol) {
super(protocol);
Expand Down Expand Up @@ -1248,6 +1254,32 @@ public void testFlatVectorSimilarity() {
assertEquals("0", doc1.get("__v_score"));
}

@Test
public void testFlatVectorSimilarityInt8() {
assumeTrue("INT8",
RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3));
assertOK(client.ftCreate(index,
VectorField.builder().fieldName("v").algorithm(VectorAlgorithm.FLAT)
.addAttribute("TYPE", "INT8").addAttribute("DIM", 2)
.addAttribute("DISTANCE_METRIC", "L2").build()));

byte[] a = { 127, 1 };
byte[] b = { 127, 10 };
byte[] c = { 127, 100 };

client.hset("a".getBytes(), "v".getBytes(), a);
client.hset("b".getBytes(), "v".getBytes(), b);
client.hset("c".getBytes(), "v".getBytes(), c);

FTSearchParams searchParams = FTSearchParams.searchParams().addParam("vec", a)
.sortBy("__v_score", SortingOrder.ASC).returnFields("__v_score");

Document doc1 = client.ftSearch(index, "*=>[KNN 2 @v $vec]", searchParams).getDocuments()
.get(0);
assertEquals("a", doc1.getId());
assertEquals("0", doc1.get("__v_score"));
}

@Test
@SinceRedisVersion(value = "7.4.0", message = "no optional params before 7.4.0")
public void vectorFieldParams() {
Expand Down Expand Up @@ -1286,6 +1318,26 @@ public void bfloat16StorageType() {
.build()));
}

@Test
public void int8StorageType() {
assumeTrue("INT8",
RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3));
assertOK(client.ftCreate(index,
VectorField.builder().fieldName("v").algorithm(VectorAlgorithm.HNSW)
.addAttribute("TYPE", "INT8").addAttribute("DIM", 4)
.addAttribute("DISTANCE_METRIC", "L2").build()));
}

@Test
public void uint8StorageType() {
assumeTrue("UINT8",
RedisConditions.of(client).moduleVersionIsGreaterThanOrEqual(SEARCH_MOD_VER_80M3));
assertOK(client.ftCreate(index,
VectorField.builder().fieldName("v").algorithm(VectorAlgorithm.HNSW)
.addAttribute("TYPE", "UINT8").addAttribute("DIM", 4)
.addAttribute("DISTANCE_METRIC", "L2").build()));
}

@Test
public void searchProfile() {
assertOK(client.ftCreate(index, TextField.of("t1"), TextField.of("t2")));
Expand Down
51 changes: 41 additions & 10 deletions src/test/java/redis/clients/jedis/util/RedisConditions.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,33 @@

public class RedisConditions {

public enum ModuleVersion {

SEARCH_MOD_VER_80M3("SEARCH", 79903);

private final String moduleName;
private final int version;

ModuleVersion(String moduleName, int version) {
this.moduleName = moduleName;
this.version = version;
}

public String getModuleName() {
return moduleName;
}

public int getVersion() {
return version;
}
}

private final RedisVersion version;
private final Map<String, Integer> modules;
private final Map<String, CommandInfo> commands;

private RedisConditions(RedisVersion version, Map< String, CommandInfo> commands, Map<String, Integer> modules) {
private RedisConditions(RedisVersion version, Map<String, CommandInfo> commands,
Map<String, Integer> modules) {
this.version = version;
this.commands = commands;
this.modules = modules;
Expand All @@ -31,15 +53,14 @@ private RedisConditions(RedisVersion version, Map< String, CommandInfo> commands
public static RedisConditions of(UnifiedJedis jedis) {
RedisVersion version = RedisVersionUtil.getRedisVersion(jedis);

CommandObject<Map<String, CommandInfo>> commandInfoCmd
= new CommandObject<>(new CommandArguments(COMMAND), CommandInfo.COMMAND_INFO_RESPONSE);
CommandObject<Map<String, CommandInfo>> commandInfoCmd = new CommandObject<>(
new CommandArguments(COMMAND), CommandInfo.COMMAND_INFO_RESPONSE);
Map<String, CommandInfo> commands = jedis.executeCommand(commandInfoCmd);

CommandObject<List<Module>> moduleListCmd
= new CommandObject<>(new CommandArguments(MODULE).add(LIST), MODULE_LIST);
CommandObject<List<Module>> moduleListCmd = new CommandObject<>(
new CommandArguments(MODULE).add(LIST), MODULE_LIST);

Map<String, Integer> modules = jedis.executeCommand(moduleListCmd)
.stream()
Map<String, Integer> modules = jedis.executeCommand(moduleListCmd).stream()
.collect(Collectors.toMap((m) -> m.getName().toUpperCase(), Module::getVersion));

return new RedisConditions(version, commands, modules);
Expand Down Expand Up @@ -68,10 +89,20 @@ public boolean hasModule(String module) {
/**
* @param module
* @param version
* @return {@code true} if the module is present.
* @return {@code true} if the module with the requested minimum version is present.
*/
public boolean moduleVersionIsGreatherThan(String module, int version) {
public boolean moduleVersionIsGreaterThanOrEqual(String module, int version) {
Integer moduleVersion = modules.get(module.toUpperCase());
return moduleVersion != null && moduleVersion > version;
return moduleVersion != null && moduleVersion >= version;
}

/**
* @param moduleVersion
* @return {@code true} if the module version is greater than or equal to the specified version.
*/
public boolean moduleVersionIsGreaterThanOrEqual(ModuleVersion moduleVersion) {
return moduleVersionIsGreaterThanOrEqual(moduleVersion.getModuleName(),
moduleVersion.getVersion());
}

}
Loading