Skip to content

Commit b564a8c

Browse files
committed
MLE-17715 Can now configure embedder batch size
Going to test this out next on some larger datasets to see what kinds of errors we'll need to handle from the embedding model.
1 parent 291bc5e commit b564a8c

File tree

10 files changed

+206
-27
lines changed

10 files changed

+206
-27
lines changed

src/main/java/com/marklogic/spark/Options.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,13 @@ public abstract class Options {
250250
*/
251251
public static final String WRITE_EMBEDDER_EMBEDDING_NAMESPACE = "spark.marklogic.write.embedder.embedding.namespace";
252252

253+
/**
254+
* Defines the number of chunks to send to the embedding model in a single call. Defaults to 1.
255+
*
256+
* @since 2.5.0
257+
*/
258+
public static final String WRITE_EMBEDDER_BATCH_SIZE = "spark.marklogic.write.embedder.batchSize";
259+
253260
private Options() {
254261
}
255262
}

src/main/java/com/marklogic/spark/writer/embedding/EmbedderDocumentProcessor.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import com.marklogic.client.document.DocumentWriteOperation;
77
import com.marklogic.spark.writer.DocumentProcessor;
8-
import dev.langchain4j.model.embedding.EmbeddingModel;
98

109
import java.util.Iterator;
1110
import java.util.stream.Stream;
@@ -19,9 +18,9 @@ class EmbedderDocumentProcessor implements DocumentProcessor {
1918
private final ChunkSelector chunkSelector;
2019
private final EmbeddingGenerator embeddingGenerator;
2120

22-
EmbedderDocumentProcessor(ChunkSelector chunkSelector, EmbeddingModel embeddingModel) {
21+
EmbedderDocumentProcessor(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator) {
2322
this.chunkSelector = chunkSelector;
24-
this.embeddingGenerator = new EmbeddingGenerator(embeddingModel);
23+
this.embeddingGenerator = embeddingGenerator;
2524
}
2625

2726
@Override

src/main/java/com/marklogic/spark/writer/embedding/EmbedderDocumentProcessorFactory.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import com.marklogic.spark.ConnectorException;
77
import com.marklogic.spark.ContextSupport;
88
import com.marklogic.spark.Options;
9-
import com.marklogic.spark.writer.dom.XPathNamespaceContext;
109
import com.marklogic.spark.writer.DocumentProcessor;
10+
import com.marklogic.spark.writer.dom.XPathNamespaceContext;
1111
import dev.langchain4j.model.embedding.EmbeddingModel;
1212

1313
import java.util.HashMap;
@@ -21,11 +21,21 @@ public static Optional<DocumentProcessor> makeEmbedder(ContextSupport context) {
2121
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
2222
if (embeddingModel.isPresent()) {
2323
ChunkSelector chunkSelector = makeChunkSelector(context);
24-
return Optional.of(new EmbedderDocumentProcessor(chunkSelector, embeddingModel.get()));
24+
EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context);
25+
return Optional.of(new EmbedderDocumentProcessor(chunkSelector, embeddingGenerator));
2526
}
2627
return Optional.empty();
2728
}
2829

30+
public static EmbeddingGenerator makeEmbeddingGenerator(ContextSupport context) {
31+
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
32+
if (embeddingModel.isPresent()) {
33+
int batchSize = context.getIntOption(Options.WRITE_EMBEDDER_BATCH_SIZE, 1, 1);
34+
return new EmbeddingGenerator(embeddingModel.get(), batchSize);
35+
}
36+
return null;
37+
}
38+
2939
/**
3040
* If the user is also splitting the documents, then we'll know the location of the chunks based on the default
3141
* chunks data structure produced by the splitter. If the user is instead processing documents that already have
@@ -74,7 +84,7 @@ private static ChunkSelector makeXmlChunkSelector(ContextSupport context) {
7484
);
7585
}
7686

77-
public static Optional<EmbeddingModel> makeEmbeddingModel(ContextSupport context) {
87+
private static Optional<EmbeddingModel> makeEmbeddingModel(ContextSupport context) {
7888
if (!context.hasOption(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME)) {
7989
return Optional.empty();
8090
}

src/main/java/com/marklogic/spark/writer/embedding/EmbeddingGenerator.java

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,78 @@
44
package com.marklogic.spark.writer.embedding;
55

66
import com.marklogic.spark.Util;
7+
import dev.langchain4j.data.document.Metadata;
78
import dev.langchain4j.data.embedding.Embedding;
9+
import dev.langchain4j.data.segment.TextSegment;
810
import dev.langchain4j.model.embedding.EmbeddingModel;
911
import dev.langchain4j.model.output.Response;
1012

13+
import java.util.ArrayList;
14+
import java.util.Iterator;
1115
import java.util.List;
16+
import java.util.stream.Collectors;
1217

1318
/**
1419
* Knows how to generate and add embeddings for each chunk. Supports a configurable batch size so that more than one
1520
* chunk can be sent to an embedding model in a single call.
1621
*/
1722
public class EmbeddingGenerator {
1823

19-
private EmbeddingModel embeddingModel;
24+
// We don't have any use for metadata, so we just need a single instance for constructing text segments.
25+
private static final Metadata TEXT_SEGMENT_METADATA = new Metadata();
26+
27+
private final EmbeddingModel embeddingModel;
28+
private final int batchSize;
2029

2130
public EmbeddingGenerator(EmbeddingModel embeddingModel) {
31+
this(embeddingModel, 1);
32+
}
33+
34+
public EmbeddingGenerator(EmbeddingModel embeddingModel, int batchSize) {
2235
this.embeddingModel = embeddingModel;
36+
this.batchSize = batchSize;
2337
}
2438

2539
public void addEmbeddings(List<Chunk> chunks) {
26-
if (chunks != null) {
27-
chunks.forEach(chunk -> {
28-
String text = chunk.getEmbeddingText();
29-
if (text != null && text.trim().length() > 0) {
30-
Response<Embedding> response = embeddingModel.embed(text);
31-
chunk.addEmbedding(response.content());
32-
} else if (Util.MAIN_LOGGER.isDebugEnabled()) {
33-
Util.MAIN_LOGGER.debug("Not generating embedding for chunk in URI {}; could not find text to use for generating an embedding.",
34-
chunk.getDocumentUri());
40+
if (chunks == null || chunks.isEmpty()) {
41+
return;
42+
}
43+
44+
Iterator<Chunk> chunkIterator = chunks.iterator();
45+
List<Chunk> batch = new ArrayList<>();
46+
while (chunkIterator.hasNext()) {
47+
Chunk chunk = chunkIterator.next();
48+
String text = chunk.getEmbeddingText();
49+
if (text != null && text.trim().length() > 0) {
50+
batch.add(chunk);
51+
if (batch.size() >= this.batchSize) {
52+
addEmbeddingsToChunks(batch);
53+
batch = new ArrayList<>();
3554
}
36-
});
55+
} else if (Util.MAIN_LOGGER.isDebugEnabled()) {
56+
Util.MAIN_LOGGER.debug("Not generating embedding for chunk in URI {}; could not find text to use for generating an embedding.",
57+
chunk.getDocumentUri());
58+
}
59+
}
60+
61+
if (!batch.isEmpty()) {
62+
addEmbeddingsToChunks(batch);
63+
}
64+
}
65+
66+
private void addEmbeddingsToChunks(List<Chunk> chunks) {
67+
List<TextSegment> textSegments = chunks.stream()
68+
.map(chunk -> new TextSegment(chunk.getEmbeddingText(), TEXT_SEGMENT_METADATA))
69+
.collect(Collectors.toList());
70+
71+
Response<List<Embedding>> response = embeddingModel.embedAll(textSegments);
72+
if (Util.MAIN_LOGGER.isDebugEnabled()) {
73+
Util.MAIN_LOGGER.debug("Sent {} chunks; token usage: {}", textSegments.size(), response.tokenUsage());
74+
}
75+
76+
List<Embedding> embeddings = response.content();
77+
for (int i = 0; i < embeddings.size(); i++) {
78+
chunks.get(i).addEmbedding(embeddings.get(i));
3779
}
3880
}
3981
}

src/main/java/com/marklogic/spark/writer/splitter/SplitterDocumentProcessorFactory.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
import com.marklogic.spark.writer.DocumentProcessor;
1111
import com.marklogic.spark.writer.dom.XPathNamespaceContext;
1212
import com.marklogic.spark.writer.embedding.EmbedderDocumentProcessorFactory;
13-
import com.marklogic.spark.writer.embedding.EmbeddingGenerator;
1413
import dev.langchain4j.data.document.DocumentSplitter;
15-
import dev.langchain4j.model.embedding.EmbeddingModel;
1614

1715
import java.util.Arrays;
1816
import java.util.Optional;
@@ -85,12 +83,6 @@ private static ChunkAssembler makeChunkAssembler(ContextSupport context) {
8583
metadata.getPermissions().addFromDelimitedString(value);
8684
}
8785

88-
EmbeddingGenerator embeddingGenerator = null;
89-
Optional<EmbeddingModel> embeddingModel = EmbedderDocumentProcessorFactory.makeEmbeddingModel(context);
90-
if (embeddingModel.isPresent()) {
91-
embeddingGenerator = new EmbeddingGenerator(embeddingModel.get());
92-
}
93-
9486
return new DefaultChunkAssembler(new ChunkConfig.Builder()
9587
.withMetadata(metadata)
9688
.withMaxChunks(context.getIntOption(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 0, 0))
@@ -100,7 +92,7 @@ private static ChunkAssembler makeChunkAssembler(ContextSupport context) {
10092
.withUriSuffix(context.getStringOption(Options.WRITE_SPLITTER_SIDECAR_URI_SUFFIX))
10193
.withXmlNamespace(context.getStringOption(Options.WRITE_SPLITTER_SIDECAR_XML_NAMESPACE))
10294
.build(),
103-
embeddingGenerator
95+
EmbedderDocumentProcessorFactory.makeEmbeddingGenerator(context)
10496
);
10597
}
10698

src/main/resources/marklogic-spark-messages.properties

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ spark.marklogic.write.splitter.maxChunkSize=
2020
spark.marklogic.write.splitter.maxOverlapSize=
2121
spark.marklogic.write.embedder.chunks.jsonPointer=
2222
spark.marklogic.write.embedder.chunks.xpath=
23+
spark.marklogic.write.embedder.batchSize=
24+

src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,45 @@ void chunksIsAnObjectInsteadOfAnArray() {
206206
assertEquals(JsonNodeType.ARRAY, doc.get("embedding").getNodeType());
207207
}
208208

209+
@Test
210+
void testBatchSize() {
211+
TestEmbeddingModel.batchCounter = 0;
212+
213+
readDocument("/marklogic-docs/java-client-intro.json")
214+
.write().format(CONNECTOR_IDENTIFIER)
215+
.option(Options.CLIENT_URI, makeClientUri())
216+
.option(Options.WRITE_SPLITTER_JSON_POINTERS, "/text")
217+
.option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
218+
.option(Options.WRITE_URI_TEMPLATE, "/split-test.json")
219+
.option(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, 500)
220+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, "com.marklogic.spark.writer.embedding.TestEmbeddingModel")
221+
.option(Options.WRITE_EMBEDDER_BATCH_SIZE, 2)
222+
.mode(SaveMode.Append)
223+
.save();
224+
225+
JsonNode doc = readJsonDocument("/split-test.json");
226+
assertEquals(4, doc.get("chunks").size());
227+
228+
assertEquals(2, TestEmbeddingModel.batchCounter, "Expecting 2 batches to be sent to the test " +
229+
"embedding model, given the batch size of 2 and 4 chunks being created.");
230+
}
231+
232+
@Test
233+
void invalidBatchSize() {
234+
DataFrameWriter writer = readDocument("/marklogic-docs/java-client-intro.json")
235+
.write().format(CONNECTOR_IDENTIFIER)
236+
.option(Options.CLIENT_URI, makeClientUri())
237+
.option(Options.WRITE_SPLITTER_JSON_POINTERS, "/text")
238+
.option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
239+
.option(Options.WRITE_URI_TEMPLATE, "/split-test.json")
240+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, "com.marklogic.spark.writer.embedding.TestEmbeddingModel")
241+
.option(Options.WRITE_EMBEDDER_BATCH_SIZE, "abc")
242+
.mode(SaveMode.Append);
243+
244+
ConnectorException ex = assertThrowsConnectorException(() -> writer.save());
245+
assertEquals("The value of 'spark.marklogic.write.embedder.batchSize' must be numeric.", ex.getMessage());
246+
}
247+
209248
private Dataset<Row> readDocument(String uri) {
210249
return newSparkSession().read().format(CONNECTOR_IDENTIFIER)
211250
.option(Options.CLIENT_URI, makeClientUri())

src/test/java/com/marklogic/spark/writer/embedding/EmbedderTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void customizedPaths() {
6262
.withTextPointer("/wrapper/custom-text")
6363
.withEmbeddingArrayName("custom-embedding")
6464
.build(),
65-
new AllMiniLmL6V2EmbeddingModel()
65+
new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel())
6666
);
6767

6868
DocumentWriteOperation output = embedder.apply(new DocumentWriteOperationImpl("a.json", null, new JacksonHandle(doc))).next();
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.writer.embedding;
5+
6+
import org.junit.jupiter.api.Test;
7+
8+
import java.util.ArrayList;
9+
import java.util.List;
10+
11+
import static org.junit.jupiter.api.Assertions.assertEquals;
12+
13+
class EmbeddingGeneratorTest {
14+
15+
@Test
16+
void test() {
17+
TestEmbeddingModel embeddingModel = new TestEmbeddingModel();
18+
TestEmbeddingModel.batchCounter = 0;
19+
20+
EmbeddingGenerator generator = new EmbeddingGenerator(embeddingModel, 2);
21+
22+
List<Chunk> chunks = new ArrayList<>();
23+
for (int i = 0; i < 5; i++) {
24+
chunks.add(new TestEmbeddingModel.TestChunk("text" + i));
25+
}
26+
27+
generator.addEmbeddings(chunks);
28+
assertEquals(3, embeddingModel.batchCounter, "3 batches should have been sent given the batch size of 2.");
29+
}
30+
31+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.writer.embedding;
5+
6+
import dev.langchain4j.data.embedding.Embedding;
7+
import dev.langchain4j.data.segment.TextSegment;
8+
import dev.langchain4j.model.embedding.EmbeddingModel;
9+
import dev.langchain4j.model.output.Response;
10+
11+
import java.util.Arrays;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.function.Function;
15+
16+
/**
17+
* Test-only embedding model that counts calls to {@code embedAll} so that tests can verify the embedder batch size feature.
18+
*/
19+
class TestEmbeddingModel implements EmbeddingModel, Function<Map<String, String>, EmbeddingModel> {
20+
21+
static int batchCounter;
22+
23+
@Override
24+
public EmbeddingModel apply(Map<String, String> options) {
25+
return this;
26+
}
27+
28+
@Override
29+
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
30+
batchCounter++;
31+
return Response.from(Arrays.asList(new Embedding(new float[]{1})));
32+
}
33+
34+
static class TestChunk implements Chunk {
35+
36+
private final String text;
37+
38+
TestChunk(String text) {
39+
this.text = text;
40+
}
41+
42+
@Override
43+
public String getDocumentUri() {
44+
return "/doesnt/matter.json";
45+
}
46+
47+
@Override
48+
public String getEmbeddingText() {
49+
return text;
50+
}
51+
52+
@Override
53+
public void addEmbedding(Embedding embedding) {
54+
// Don't need to do this for the purposes of our test.
55+
}
56+
}
57+
}

0 commit comments

Comments
 (0)