Skip to content

Commit 2782324

Browse files
committed
Not adding invalid embeddings
For now, "invalid" = a vector with all zeroes in it. Which Ollama seems to spit out intermittently. And MarkLogic can't do anything useful with it, and in fact it will throw an error, so we're not adding it.
1 parent 6a70ed7 commit 2782324

File tree

5 files changed

+79
-15
lines changed

5 files changed

+79
-15
lines changed

marklogic-langchain4j/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingGenerator.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,34 @@ private void addEmbeddingsToChunks(List<Chunk> chunks) {
9797
} else {
9898
List<Embedding> embeddings = response.content();
9999
for (int i = 0; i < embeddings.size(); i++) {
100-
chunks.get(i).addEmbedding(embeddings.get(i));
100+
addEmbeddingToChunk(chunks.get(i), embeddings.get(i));
101101
}
102102
}
103103
}
104104

105+
private void addEmbeddingToChunk(Chunk chunk, Embedding embedding) {
106+
if (vectorIsAllZeroes(embedding)) {
107+
if (Util.LANGCHAIN4J_LOGGER.isDebugEnabled()) {
108+
Util.LANGCHAIN4J_LOGGER.debug("Not adding embedding to chunk as it only contains zeroes; source document URI: {}; text: {}",
109+
chunk.getDocumentUri(), chunk.getEmbeddingText());
110+
} else {
111+
Util.LANGCHAIN4J_LOGGER.warn("Not adding embedding to chunk as it only contains zeroes; source document URI: {}",
112+
chunk.getDocumentUri());
113+
}
114+
} else {
115+
chunk.addEmbedding(embedding);
116+
}
117+
}
118+
119+
private boolean vectorIsAllZeroes(Embedding embedding) {
120+
for (float f : embedding.vector()) {
121+
if (f != 0.0f) {
122+
return false;
123+
}
124+
}
125+
return true;
126+
}
127+
105128
private List<TextSegment> makeTextSegments(List<Chunk> chunks) {
106129
return chunks.stream()
107130
.map(chunk -> new TextSegment(chunk.getEmbeddingText(), TEXT_SEGMENT_METADATA))

marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,42 @@ void addEmbeddingsToExistingSplits() {
103103
verifyEachChunkIsReturnedByAVectorQuery();
104104
}
105105

106+
@ExtendWith(RequiresMarkLogic12.class)
107+
@Test
108+
void vectorHasAllZeroes() {
109+
readDocument("/marklogic-docs/java-client-intro.json")
110+
.repartition(1)
111+
.write().format(CONNECTOR_IDENTIFIER)
112+
.option(Options.CLIENT_URI, makeClientUri())
113+
.option(Options.WRITE_SPLITTER_JSON_POINTERS, "/text")
114+
.option(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 10)
115+
.option(Options.WRITE_SPLITTER_SIDECAR_COLLECTIONS, "json-vector-chunks")
116+
.option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
117+
.option(Options.WRITE_URI_TEMPLATE, "/split-test.json")
118+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, "com.marklogic.spark.writer.embedding.TestEmbeddingModel")
119+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_OPTION_PREFIX + "returnZeroesOnFirstCall", "true")
120+
.mode(SaveMode.Append)
121+
.save();
122+
123+
JsonNode doc = readJsonDocument("/split-test.json-chunks-1.json");
124+
JsonNode firstChunk = doc.get("chunks").get(0);
125+
assertFalse(firstChunk.has("embedding"), "The first chunk is given an array of all zeroes by the test " +
126+
"embedding model. Flux should recognize this and not add the `embedding` field, as doing so will cause " +
127+
"issues with the Optic vector library - specifically, a VEC-MAGNITUDEZERO error at least when using " +
128+
"vec.cosineSimilarity and then sorting on the values. A future version of MarkLogic 12 may improve this " +
129+
"by allowing for an array of zeroes to be rejected.");
130+
131+
JsonNode secondChunk = doc.get("chunks").get(1);
132+
assertTrue(secondChunk.has("embedding"), "The test embedding model should generate a valid embedding for " +
133+
"the second chunk, which means it can be queried next using Optic.");
134+
135+
RowManager rowManager = getDatabaseClient().newRowManager();
136+
RowSet<RowRecord> rows = rowManager.resultRows(rowManager.newPlanBuilder().fromView("example", "json_chunks"));
137+
assertEquals(1, rows.stream().count(), "The TDE has nullable=false for the embedding column, as a null " +
138+
"vector will cause issues when querying on vectors. And since invalidValues=ignore, the first chunk " +
139+
"won't be returned; only the second chunk will be.");
140+
}
141+
106142
@Test
107143
void passOptionsToEmbeddingModelFunction() {
108144
DataFrameWriter writer = readDocument("/marklogic-docs/java-client-intro.json")

marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/TestEmbeddingModel.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dev.langchain4j.data.embedding.Embedding;
88
import dev.langchain4j.data.segment.TextSegment;
99
import dev.langchain4j.model.embedding.EmbeddingModel;
10+
import dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel;
1011
import dev.langchain4j.model.output.Response;
1112

1213
import java.util.Arrays;
@@ -27,8 +28,13 @@ public static void reset() {
2728
chunkCounter = 0;
2829
}
2930

31+
private static AllMiniLmL6V2EmbeddingModel realEmbeddingModel = new AllMiniLmL6V2EmbeddingModel();
32+
33+
private boolean returnZeroesOnFirstCall;
34+
3035
@Override
3136
public EmbeddingModel apply(Map<String, String> options) {
37+
returnZeroesOnFirstCall = "true".equals(options.get("returnZeroesOnFirstCall"));
3238
return this;
3339
}
3440

@@ -41,7 +47,11 @@ public int dimension() {
4147
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
4248
batchCounter++;
4349
chunkCounter += textSegments.size();
44-
return Response.from(Arrays.asList(new Embedding(new float[]{1})));
50+
if (returnZeroesOnFirstCall) {
51+
returnZeroesOnFirstCall = false;
52+
return Response.from(Arrays.asList(new Embedding(new float[384])));
53+
}
54+
return realEmbeddingModel.embedAll(textSegments);
4555
}
4656

4757
public static class TestChunk implements Chunk {

marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public abstract class EmbeddingAdderFactory {
2121
public static Optional<EmbeddingAdder> makeEmbedder(Context context, DocumentTextSplitter splitter) {
2222
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
2323
if (embeddingModel.isPresent()) {
24-
EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context);
24+
EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context, embeddingModel.get());
2525
if (splitter != null) {
2626
return Optional.of(new EmbeddingAdder(splitter, embeddingGenerator));
2727
}
@@ -31,17 +31,12 @@ public static Optional<EmbeddingAdder> makeEmbedder(Context context, DocumentTex
3131
return Optional.empty();
3232
}
3333

34-
public static EmbeddingGenerator makeEmbeddingGenerator(Context context) {
35-
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
36-
if (embeddingModel.isPresent()) {
37-
int batchSize = context.getIntOption(Options.WRITE_EMBEDDER_BATCH_SIZE, 1, 1);
38-
EmbeddingModel model = embeddingModel.get();
39-
if (Util.MAIN_LOGGER.isInfoEnabled()) {
40-
Util.MAIN_LOGGER.info("Using embedding model with dimension: {}", model.dimension());
41-
}
42-
return new EmbeddingGenerator(model, batchSize);
34+
private static EmbeddingGenerator makeEmbeddingGenerator(Context context, EmbeddingModel model) {
35+
int batchSize = context.getIntOption(Options.WRITE_EMBEDDER_BATCH_SIZE, 1, 1);
36+
if (Util.MAIN_LOGGER.isInfoEnabled()) {
37+
Util.MAIN_LOGGER.info("Using embedding model with dimension: {}", model.dimension());
4338
}
44-
return null;
39+
return new EmbeddingGenerator(model, batchSize);
4540
}
4641

4742
/**

test-app/src/main/ml-schemas-12/tde/json-vector-chunks.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
"scalarType": "vector",
2020
"val": "vec:vector(embedding)",
2121
"dimension": "384",
22-
"invalidValues": "reject",
23-
"nullable": true
22+
"invalidValues": "ignore",
23+
"nullable": false
2424
}
2525
]
2626
}

0 commit comments

Comments
 (0)