Skip to content

Commit 75006f2

Browse files
authored
Merge pull request #373 from marklogic/feature/zero-vector-fix
Not adding invalid embeddings
2 parents 6a70ed7 + 2782324 commit 75006f2

File tree

5 files changed

+79
-15
lines changed

5 files changed

+79
-15
lines changed

marklogic-langchain4j/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingGenerator.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,34 @@ private void addEmbeddingsToChunks(List<Chunk> chunks) {
9797
} else {
9898
List<Embedding> embeddings = response.content();
9999
for (int i = 0; i < embeddings.size(); i++) {
100-
chunks.get(i).addEmbedding(embeddings.get(i));
100+
addEmbeddingToChunk(chunks.get(i), embeddings.get(i));
101101
}
102102
}
103103
}
104104

105+
private void addEmbeddingToChunk(Chunk chunk, Embedding embedding) {
106+
if (vectorIsAllZeroes(embedding)) {
107+
if (Util.LANGCHAIN4J_LOGGER.isDebugEnabled()) {
108+
Util.LANGCHAIN4J_LOGGER.debug("Not adding embedding to chunk as it only contains zeroes; source document URI: {}; text: {}",
109+
chunk.getDocumentUri(), chunk.getEmbeddingText());
110+
} else {
111+
Util.LANGCHAIN4J_LOGGER.warn("Not adding embedding to chunk as it only contains zeroes; source document URI: {}",
112+
chunk.getDocumentUri());
113+
}
114+
} else {
115+
chunk.addEmbedding(embedding);
116+
}
117+
}
118+
119+
private boolean vectorIsAllZeroes(Embedding embedding) {
120+
for (float f : embedding.vector()) {
121+
if (f != 0.0f) {
122+
return false;
123+
}
124+
}
125+
return true;
126+
}
127+
105128
private List<TextSegment> makeTextSegments(List<Chunk> chunks) {
106129
return chunks.stream()
107130
.map(chunk -> new TextSegment(chunk.getEmbeddingText(), TEXT_SEGMENT_METADATA))

marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,42 @@ void addEmbeddingsToExistingSplits() {
103103
verifyEachChunkIsReturnedByAVectorQuery();
104104
}
105105

106+
@ExtendWith(RequiresMarkLogic12.class)
107+
@Test
108+
void vectorHasAllZeroes() {
109+
readDocument("/marklogic-docs/java-client-intro.json")
110+
.repartition(1)
111+
.write().format(CONNECTOR_IDENTIFIER)
112+
.option(Options.CLIENT_URI, makeClientUri())
113+
.option(Options.WRITE_SPLITTER_JSON_POINTERS, "/text")
114+
.option(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 10)
115+
.option(Options.WRITE_SPLITTER_SIDECAR_COLLECTIONS, "json-vector-chunks")
116+
.option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
117+
.option(Options.WRITE_URI_TEMPLATE, "/split-test.json")
118+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, "com.marklogic.spark.writer.embedding.TestEmbeddingModel")
119+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_OPTION_PREFIX + "returnZeroesOnFirstCall", "true")
120+
.mode(SaveMode.Append)
121+
.save();
122+
123+
JsonNode doc = readJsonDocument("/split-test.json-chunks-1.json");
124+
JsonNode firstChunk = doc.get("chunks").get(0);
125+
assertFalse(firstChunk.has("embedding"), "The first chunk is given an array of all zeroes by the test " +
126+
"embedding model. Flux should recognize this and not add the `embedding` field, as doing so will cause " +
127+
"issues with the Optic vector library - specifically, a VEC-MAGNITUDEZERO error at least when using " +
128+
"vec.cosineSimilarity and then sorting on the values. A future version of MarkLogic 12 may improve this " +
129+
"by allowing for an array of zeroes to be rejected.");
130+
131+
JsonNode secondChunk = doc.get("chunks").get(1);
132+
assertTrue(secondChunk.has("embedding"), "The test embedding model should generate a valid embedding for " +
133+
"the second chunk, which means it can be queried next using Optic.");
134+
135+
RowManager rowManager = getDatabaseClient().newRowManager();
136+
RowSet<RowRecord> rows = rowManager.resultRows(rowManager.newPlanBuilder().fromView("example", "json_chunks"));
137+
assertEquals(1, rows.stream().count(), "The TDE has nullable=false for the embedding column, as a null " +
138+
"vector will cause issues when querying on vectors. And since invalidValues=ignore, the first chunk " +
139+
"won't be returned; only the second chunk will be.");
140+
}
141+
106142
@Test
107143
void passOptionsToEmbeddingModelFunction() {
108144
DataFrameWriter writer = readDocument("/marklogic-docs/java-client-intro.json")

marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/TestEmbeddingModel.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dev.langchain4j.data.embedding.Embedding;
88
import dev.langchain4j.data.segment.TextSegment;
99
import dev.langchain4j.model.embedding.EmbeddingModel;
10+
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel;
1011
import dev.langchain4j.model.output.Response;
1112

1213
import java.util.Arrays;
@@ -27,8 +28,13 @@ public static void reset() {
2728
chunkCounter = 0;
2829
}
2930

31+
private static AllMiniLmL6V2EmbeddingModel realEmbeddingModel = new AllMiniLmL6V2EmbeddingModel();
32+
33+
private boolean returnZeroesOnFirstCall;
34+
3035
@Override
3136
public EmbeddingModel apply(Map<String, String> options) {
37+
returnZeroesOnFirstCall = "true".equals(options.get("returnZeroesOnFirstCall"));
3238
return this;
3339
}
3440

@@ -41,7 +47,11 @@ public int dimension() {
4147
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
4248
batchCounter++;
4349
chunkCounter += textSegments.size();
44-
return Response.from(Arrays.asList(new Embedding(new float[]{1})));
50+
if (returnZeroesOnFirstCall) {
51+
returnZeroesOnFirstCall = false;
52+
return Response.from(Arrays.asList(new Embedding(new float[384])));
53+
}
54+
return realEmbeddingModel.embedAll(textSegments);
4555
}
4656

4757
public static class TestChunk implements Chunk {

marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public abstract class EmbeddingAdderFactory {
2121
public static Optional<EmbeddingAdder> makeEmbedder(Context context, DocumentTextSplitter splitter) {
2222
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
2323
if (embeddingModel.isPresent()) {
24-
EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context);
24+
EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context, embeddingModel.get());
2525
if (splitter != null) {
2626
return Optional.of(new EmbeddingAdder(splitter, embeddingGenerator));
2727
}
@@ -31,17 +31,12 @@ public static Optional<EmbeddingAdder> makeEmbedder(Context context, DocumentTex
3131
return Optional.empty();
3232
}
3333

34-
public static EmbeddingGenerator makeEmbeddingGenerator(Context context) {
35-
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
36-
if (embeddingModel.isPresent()) {
37-
int batchSize = context.getIntOption(Options.WRITE_EMBEDDER_BATCH_SIZE, 1, 1);
38-
EmbeddingModel model = embeddingModel.get();
39-
if (Util.MAIN_LOGGER.isInfoEnabled()) {
40-
Util.MAIN_LOGGER.info("Using embedding model with dimension: {}", model.dimension());
41-
}
42-
return new EmbeddingGenerator(model, batchSize);
34+
private static EmbeddingGenerator makeEmbeddingGenerator(Context context, EmbeddingModel model) {
35+
int batchSize = context.getIntOption(Options.WRITE_EMBEDDER_BATCH_SIZE, 1, 1);
36+
if (Util.MAIN_LOGGER.isInfoEnabled()) {
37+
Util.MAIN_LOGGER.info("Using embedding model with dimension: {}", model.dimension());
4338
}
44-
return null;
39+
return new EmbeddingGenerator(model, batchSize);
4540
}
4641

4742
/**

test-app/src/main/ml-schemas-12/tde/json-vector-chunks.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
"scalarType": "vector",
2020
"val": "vec:vector(embedding)",
2121
"dimension": "384",
22-
"invalidValues": "reject",
23-
"nullable": true
22+
"invalidValues": "ignore",
23+
"nullable": false
2424
}
2525
]
2626
}

0 commit comments

Comments
 (0)