Skip to content

MLE-19374 Updated cosine and annTopK functions #1779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1507,11 +1507,11 @@ public interface ModifyPlan extends PreparePlan, PlanBuilderBase.ModifyPlanBase
* @param vectorColumn The ann-indexed vector column to perform the index lookup against.
* @param queryVector Specifies the query vector to perform the index lookup with.
* @param distanceColumn Optional output column that captures the values of the distance metric of the vectors retrieved from the index associated with vectorColumn and the queryVector.
* @param queryTolerance Specifies the query tolerance to help balance recall and search time. The value is between 0.0 and 1.0. At 0.0, the recall will be highest. At 1.0 the recall will likely see a large degradation, but queries will be quick. The default value is 0.0.
* @param options Optional sequence of strings or a map containing keys and values for the options to this operator.
* @return a ModifyPlan object
* @since 7.1.0
* @since 7.2.0
*/
ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, float queryTolerance);
ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, Map<String, Object> options);

/**
* This method restricts the left row set to rows where a row with the same columns and values doesn't exist in the right row set.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@

package com.marklogic.client.expression;

import com.marklogic.client.type.XsAnyAtomicTypeSeqVal;
import com.marklogic.client.type.XsDoubleVal;
import com.marklogic.client.type.XsFloatVal;
import com.marklogic.client.type.XsStringVal;
import com.marklogic.client.type.XsUnsignedIntVal;
import com.marklogic.client.type.XsUnsignedLongVal;

import com.marklogic.client.type.ServerExpression;

// IMPORTANT: Do not edit. This file is generated.
Expand Down Expand Up @@ -59,15 +52,15 @@ public interface VecExpr {
/**
* Returns the cosine similarity between two vectors. The vectors must be of the same dimension.
*
* <a name="ml-server-type-cosine-similarity"></a>
* <a name="ml-server-type-cosine"></a>

* <p>
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:cosine-similarity" target="mlserverdoc">vec:cosine-similarity</a> server function.
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:cosine" target="mlserverdoc">vec:cosine</a> server function.
* @param vector1 The vector from which to calculate the cosine similarity with vector2. (of <a href="{@docRoot}/doc-files/types/vec_vector.html">vec:vector</a>)
* @param vector2 The vector from which to calculate the cosine similarity with vector1. (of <a href="{@docRoot}/doc-files/types/vec_vector.html">vec:vector</a>)
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a> server data type
*/
public ServerExpression cosineSimilarity(ServerExpression vector1, ServerExpression vector2);
public ServerExpression cosine(ServerExpression vector1, ServerExpression vector2);
/**
* Returns the dimension of the vector passed in.
*
Expand Down Expand Up @@ -187,7 +180,7 @@ public interface VecExpr {
* <p>
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
*/
public ServerExpression vectorScore(ServerExpression score, double similarity);
Expand All @@ -199,7 +192,7 @@ public interface VecExpr {
* <p>
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
*/
public ServerExpression vectorScore(ServerExpression score, ServerExpression similarity);
Expand All @@ -208,7 +201,7 @@ public interface VecExpr {
* <p>
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @param similarityWeight The weight of the vector similarity on the score. The default value is 0.1. If 0.0 is passed in, vector similarity has no effect. If passed a value less than 0.0 or greater than 1.0, throw VEC-VECTORSCORE. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
*/
Expand All @@ -218,7 +211,7 @@ public interface VecExpr {
* <p>
* Provides a client interface to the <a href="http://docs.marklogic.com/vec:vector-score" target="mlserverdoc">vec:vector-score</a> server function.
* @param score The cts:score of the matching document. (of <a href="{@docRoot}/doc-files/types/xs_unsignedInt.html">xs:unsignedInt</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine-similarity(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @param similarity The similarity between the vector in the matching document and the query vector. The result of a call to ovec:cosine(). In the case that the vectors are normalized, pass ovec:dot-product(). Note that vec:euclidean-distance() should not be used here. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @param similarityWeight The weight of the vector similarity on the score. The default value is 0.1. If 0.0 is passed in, vector similarity has no effect. If passed a value less than 0.0 or greater than 1.0, throw VEC-VECTORSCORE. (of <a href="{@docRoot}/doc-files/types/xs_double.html">xs:double</a>)
* @return a server expression with the <a href="{@docRoot}/doc-files/types/xs_unsignedLong.html">xs:unsignedLong</a> server data type
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,9 @@ static class ModifyPlanSubImpl
}

@Override
public ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, float queryTolerance) {
public ModifyPlan annTopK(int k, PlanColumn vectorColumn, ServerExpression queryVector, PlanColumn distanceColumn, Map<String, Object> options) {
return new PlanBuilderSubImpl.ModifyPlanSubImpl(this, "op", "annTopK", new Object[]{
k, vectorColumn, queryVector, distanceColumn, queryTolerance
k, vectorColumn, queryVector, distanceColumn, new BaseTypeImpl.BaseMapImpl(options)
});
}

Expand Down Expand Up @@ -1029,7 +1029,7 @@ public ModifyPlan facetBy(PlanNamedGroupSeq keys) {
}
@Override
public ModifyPlan facetBy(PlanNamedGroupSeq keys, String countCol) {
return facetBy(keys, (countCol == null) ? (PlanExprCol) null : exprCol(countCol));
return facetBy(keys, (countCol == null) ? null : exprCol(countCol));
}
@Override
public ModifyPlan facetBy(PlanNamedGroupSeq keys, PlanExprCol countCol) {
Expand Down Expand Up @@ -1100,7 +1100,7 @@ public ModifyPlan remove(PlanColumn uriColumn) {
}

static class TemporalRemoval implements BaseArgImpl {
private String template;
private final String template;

public TemporalRemoval(PlanColumn temporalCollection, PlanColumn uriColumn) {
this.template = String.format("{\"temporalCollection\":%s, \"uri\": %s}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ public ServerExpression base64Encode(ServerExpression vector1) {


@Override
public ServerExpression cosineSimilarity(ServerExpression vector1, ServerExpression vector2) {
public ServerExpression cosine(ServerExpression vector1, ServerExpression vector2) {
if (vector1 == null) {
throw new IllegalArgumentException("vector1 parameter for cosineSimilarity() cannot be null");
throw new IllegalArgumentException("vector1 parameter for cosine() cannot be null");
}
if (vector2 == null) {
throw new IllegalArgumentException("vector2 parameter for cosineSimilarity() cannot be null");
throw new IllegalArgumentException("vector2 parameter for cosine() cannot be null");
}
return new XsExprImpl.DoubleCallImpl("vec", "cosine-similarity", new Object[]{ vector1, vector2 });
return new XsExprImpl.DoubleCallImpl("vec", "cosine", new Object[]{ vector1, vector2 });
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;

Expand All @@ -38,7 +40,7 @@ void vectorFunctionsHappyPath() {
PlanBuilder.ModifyPlan plan =
op.fromView("vectors", "persons")
.bind(op.as("sampleVector", op.vec.vector(sampleVector)))
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
.bind(op.as("cosine", op.vec.cosine(op.col("embedding"), op.col("sampleVector"))))
.bind(op.as("dotProduct", op.vec.dotProduct(op.col("embedding"), op.col("sampleVector"))))
.bind(op.as("euclideanDistance", op.vec.euclideanDistance(op.col("embedding"), op.col("sampleVector"))))
.bind(op.as("dimension", op.vec.dimension(op.col("sampleVector"))))
Expand All @@ -52,7 +54,7 @@ void vectorFunctionsHappyPath() {
.bind(op.as("subVector", op.vec.subvector(op.col("sampleVector"), op.xs.integer(1), op.xs.integer(1))))
.bind(op.as("vectorScore", op.vec.vectorScore(op.xs.unsignedInt(1), op.xs.doubleVal(0.5))))
.select(
op.col("cosineSimilarity"), op.col("dotProduct"), op.col("euclideanDistance"),
op.col("cosine"), op.col("dotProduct"), op.col("euclideanDistance"),
op.col("name"), op.col("dimension"), op.col("normalize"),
op.col("magnitude"), op.col("get"), op.col("add"), op.col("subtract"),
op.col("base64Encode"), op.col("base64Decode"), op.col("subVector"), op.col("vectorScore")
Expand All @@ -63,8 +65,8 @@ void vectorFunctionsHappyPath() {

rows.forEach(row -> {
// Simple sanity checks to verify that the functions ran. Very little concern about the actual return values.
double cosineSimilarity = row.getDouble("cosineSimilarity");
assertTrue((cosineSimilarity > 0) && (cosineSimilarity < 1), "Unexpected value: " + cosineSimilarity);
double cosine = row.getDouble("cosine");
assertTrue((cosine > 0) && (cosine < 1), "Unexpected value: " + cosine);
double dotProduct = row.getDouble("dotProduct");
Assertions.assertTrue(dotProduct > 0, "Unexpected value: " + dotProduct);
double euclideanDistance = row.getDouble("euclideanDistance");
Expand All @@ -85,25 +87,25 @@ void vectorFunctionsHappyPath() {
}

@Test
void cosineSimilarity_DimensionMismatch() {
void cosine_DimensionMismatch() {
PlanBuilder.ModifyPlan plan =
op.fromView("vectors", "persons")
.bind(op.as("sampleVector", op.vec.vector(twoDimensionalVector)))
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
.select(op.col("name"), op.col("summary"), op.col("cosineSimilarity"));
.bind(op.as("cosine", op.vec.cosine(op.col("embedding"), op.col("sampleVector"))))
.select(op.col("name"), op.col("summary"), op.col("cosine"));
Exception exception = assertThrows(FailedRequestException.class, () -> resultRows(plan));
String actualMessage = exception.getMessage();
assertTrue(actualMessage.contains("Server Message: VEC-DIMMISMATCH"), "Unexpected message: " + actualMessage);
assertTrue(actualMessage.contains("Mismatched dimension"), "Unexpected message: " + actualMessage);
}

@Test
void cosineSimilarity_InvalidVector() {
void cosine_InvalidVector() {
PlanBuilder.ModifyPlan plan =
op.fromView("vectors", "persons")
.bind(op.as("sampleVector", invalidVector))
.bind(op.as("cosineSimilarity", op.vec.cosineSimilarity(op.col("embedding"), op.col("sampleVector"))))
.select(op.col("name"), op.col("summary"), op.col("cosineSimilarity"));
.bind(op.as("cosine", op.vec.cosine(op.col("embedding"), op.col("sampleVector"))))
.select(op.col("name"), op.col("summary"), op.col("cosine"));
Exception exception = assertThrows(FailedRequestException.class, () -> resultRows(plan));
String actualMessage = exception.getMessage();
assertTrue(actualMessage.contains("Server Message: XDMP-ARGTYPE"), "Unexpected message: " + actualMessage);
Expand Down Expand Up @@ -139,10 +141,16 @@ void vecVectorWithCol() {
assertEquals(2, rows.size());
}

/**
* Updated after 2025-06-06, when the vector functions were updated. That includes annTopK being modified to accept
* an options map as its 5th argument instead of a single query tolerance value.
*/
@Test
void annTopK() {
void annTopKWithOptionsMap() {
Map<String, Object> options = new HashMap<>();
options.put("distance", "cosine");
PlanBuilder.ModifyPlan plan = op.fromView("vectors", "persons")
.annTopK(10, op.col("embedding"), op.vec.vector(sampleVector), op.col("distance"), 0.5f);
.annTopK(10, op.col("embedding"), op.vec.vector(sampleVector), op.col("distance"), options);

List<RowRecord> rows = resultRows(plan);
assertEquals(2, rows.size(), "Verifying that annTopK worked and returned both rows from the view.");
Expand All @@ -158,7 +166,7 @@ void dslAnnTopK() {
String query = "const qualityVector = vec.vector([ 1.1, 2.2, 3.3 ]);\n" +
"op.fromView('vectors', 'persons')\n" +
" .bind(op.as('myVector', op.vec.vector(op.col('embedding'))))\n" +
" .annTopK(2, op.col('myVector'), qualityVector, op.col('distance'), 0.5)";
" .annTopK(2, op.col('myVector'), qualityVector, op.col('distance'), {'distance':'cosine'})";

RawQueryDSLPlan plan = rowManager.newRawQueryDSLPlan(new StringHandle(query));
List<RowRecord> rows = resultRows(plan);
Expand Down