Commit 9c97e20

Merge pull request #349 from marklogic/feature/17715-embedder-batch-size
MLE-17715 Can now configure embedder batch size
2 parents 291bc5e + b564a8c commit 9c97e20

10 files changed: +206 -27 lines changed

src/main/java/com/marklogic/spark/Options.java

Lines changed: 7 additions & 0 deletions
@@ -250,6 +250,13 @@ public abstract class Options {
      */
     public static final String WRITE_EMBEDDER_EMBEDDING_NAMESPACE = "spark.marklogic.write.embedder.embedding.namespace";
 
+    /**
+     * Defines the number of chunks to send to the embedding model in a single call. Defaults to 1.
+     *
+     * @since 2.5.0
+     */
+    public static final String WRITE_EMBEDDER_BATCH_SIZE = "spark.marklogic.write.embedder.batchSize";
+
     private Options() {
     }
 }

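A minimal write-side sketch of the new option, patterned on the batch-size test added below in AddEmbeddingsToJsonTest. The "marklogic" format name, connection URI, source dataset, and embedding model function class are placeholders, not values taken from this commit:

// Sketch only: "documents" stands for whatever Dataset<Row> is being written.
documents.write().format("marklogic")
    .option(Options.CLIENT_URI, "spark-user:password@localhost:8000")
    .option(Options.WRITE_PERMISSIONS, "rest-reader,read,rest-writer,update")
    .option(Options.WRITE_SPLITTER_JSON_POINTERS, "/text")
    .option(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, 500)
    .option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, "com.example.MyEmbeddingModelFunction")
    // New in this commit: send 2 chunks per call to the embedding model instead of the default of 1.
    .option(Options.WRITE_EMBEDDER_BATCH_SIZE, 2)
    .mode(SaveMode.Append)
    .save();
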
src/main/java/com/marklogic/spark/writer/embedding/EmbedderDocumentProcessor.java

Lines changed: 2 additions & 3 deletions
@@ -5,7 +5,6 @@
 
 import com.marklogic.client.document.DocumentWriteOperation;
 import com.marklogic.spark.writer.DocumentProcessor;
-import dev.langchain4j.model.embedding.EmbeddingModel;
 
 import java.util.Iterator;
 import java.util.stream.Stream;
@@ -19,9 +18,9 @@ class EmbedderDocumentProcessor implements DocumentProcessor {
     private final ChunkSelector chunkSelector;
     private final EmbeddingGenerator embeddingGenerator;
 
-    EmbedderDocumentProcessor(ChunkSelector chunkSelector, EmbeddingModel embeddingModel) {
+    EmbedderDocumentProcessor(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator) {
         this.chunkSelector = chunkSelector;
-        this.embeddingGenerator = new EmbeddingGenerator(embeddingModel);
+        this.embeddingGenerator = embeddingGenerator;
     }
 
     @Override

src/main/java/com/marklogic/spark/writer/embedding/EmbedderDocumentProcessorFactory.java

Lines changed: 13 additions & 3 deletions
@@ -6,8 +6,8 @@
 import com.marklogic.spark.ConnectorException;
 import com.marklogic.spark.ContextSupport;
 import com.marklogic.spark.Options;
-import com.marklogic.spark.writer.dom.XPathNamespaceContext;
 import com.marklogic.spark.writer.DocumentProcessor;
+import com.marklogic.spark.writer.dom.XPathNamespaceContext;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 
 import java.util.HashMap;
@@ -21,11 +21,21 @@ public static Optional<DocumentProcessor> makeEmbedder(ContextSupport context) {
         Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
         if (embeddingModel.isPresent()) {
             ChunkSelector chunkSelector = makeChunkSelector(context);
-            return Optional.of(new EmbedderDocumentProcessor(chunkSelector, embeddingModel.get()));
+            EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context);
+            return Optional.of(new EmbedderDocumentProcessor(chunkSelector, embeddingGenerator));
         }
         return Optional.empty();
     }
 
+    public static EmbeddingGenerator makeEmbeddingGenerator(ContextSupport context) {
+        Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
+        if (embeddingModel.isPresent()) {
+            int batchSize = context.getIntOption(Options.WRITE_EMBEDDER_BATCH_SIZE, 1, 1);
+            return new EmbeddingGenerator(embeddingModel.get(), batchSize);
+        }
+        return null;
+    }
+
     /**
      * If the user is also splitting the documents, then we'll know the location of the chunks based on the default
      * chunks data structure produced by the splitter. If the user is instead processing documents that already have
@@ -74,7 +84,7 @@ private static ChunkSelector makeXmlChunkSelector(ContextSupport context) {
         );
     }
 
-    public static Optional<EmbeddingModel> makeEmbeddingModel(ContextSupport context) {
+    private static Optional<EmbeddingModel> makeEmbeddingModel(ContextSupport context) {
         if (!context.hasOption(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME)) {
             return Optional.empty();
         }

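The new makeEmbeddingGenerator helper centralizes construction of the generator for both the embedder and, further down, the splitter's chunk assembler. A caller-side sketch, assuming an already-configured ContextSupport and a List<Chunk> built elsewhere:

// Sketch: the helper returns null when no embedding model function class is configured,
// so a caller treats the result as optional. "context" and "chunks" are assumed inputs.
EmbeddingGenerator generator = EmbedderDocumentProcessorFactory.makeEmbeddingGenerator(context);
if (generator != null) {
    generator.addEmbeddings(chunks); // the configured batch size is applied internally
}
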
src/main/java/com/marklogic/spark/writer/embedding/EmbeddingGenerator.java

Lines changed: 53 additions & 11 deletions
@@ -4,36 +4,78 @@
 package com.marklogic.spark.writer.embedding;
 
 import com.marklogic.spark.Util;
+import dev.langchain4j.data.document.Metadata;
 import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.output.Response;
 
+import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * Knows how to generate and add embeddings for each chunk. Will soon support a batch size so that more than one
  * chunk can be sent to an embedding model in a single call.
  */
 public class EmbeddingGenerator {
 
-    private EmbeddingModel embeddingModel;
+    // We don't have any use for metadata, so just need a single instance for constructing text segments.
+    private static final Metadata TEXT_SEGMENT_METADATA = new Metadata();
+
+    private final EmbeddingModel embeddingModel;
+    private final int batchSize;
 
     public EmbeddingGenerator(EmbeddingModel embeddingModel) {
+        this(embeddingModel, 1);
+    }
+
+    public EmbeddingGenerator(EmbeddingModel embeddingModel, int batchSize) {
         this.embeddingModel = embeddingModel;
+        this.batchSize = batchSize;
     }
 
     public void addEmbeddings(List<Chunk> chunks) {
-        if (chunks != null) {
-            chunks.forEach(chunk -> {
-                String text = chunk.getEmbeddingText();
-                if (text != null && text.trim().length() > 0) {
-                    Response<Embedding> response = embeddingModel.embed(text);
-                    chunk.addEmbedding(response.content());
-                } else if (Util.MAIN_LOGGER.isDebugEnabled()) {
-                    Util.MAIN_LOGGER.debug("Not generating embedding for chunk in URI {}; could not find text to use for generating an embedding.",
-                        chunk.getDocumentUri());
+        if (chunks == null || chunks.isEmpty()) {
+            return;
+        }
+
+        Iterator<Chunk> chunkIterator = chunks.iterator();
+        List<Chunk> batch = new ArrayList<>();
+        while (chunkIterator.hasNext()) {
+            Chunk chunk = chunkIterator.next();
+            String text = chunk.getEmbeddingText();
+            if (text != null && text.trim().length() > 0) {
+                batch.add(chunk);
+                if (batch.size() >= this.batchSize) {
+                    addEmbeddingsToChunks(batch);
+                    batch = new ArrayList<>();
                 }
-            });
+            } else if (Util.MAIN_LOGGER.isDebugEnabled()) {
+                Util.MAIN_LOGGER.debug("Not generating embedding for chunk in URI {}; could not find text to use for generating an embedding.",
+                    chunk.getDocumentUri());
+            }
+        }
+
+        if (!batch.isEmpty()) {
+            addEmbeddingsToChunks(batch);
+        }
+    }
+
+    private void addEmbeddingsToChunks(List<Chunk> chunks) {
+        List<TextSegment> textSegments = chunks.stream()
+            .map(chunk -> new TextSegment(chunk.getEmbeddingText(), TEXT_SEGMENT_METADATA))
+            .collect(Collectors.toList());
+
+        Response<List<Embedding>> response = embeddingModel.embedAll(textSegments);
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Sent {} chunks; token usage: {}", textSegments.size(), response.tokenUsage());
+        }
+
+        List<Embedding> embeddings = response.content();
+        for (int i = 0; i < embeddings.size(); i++) {
+            chunks.get(i).addEmbedding(embeddings.get(i));
         }
     }
 }

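addEmbeddings now accumulates chunks that have embedding text into batches of batchSize, sending each full batch, plus any final partial batch, to embedAll. A small sketch mirroring EmbeddingGeneratorTest below, assuming the AllMiniLmL6V2EmbeddingModel used in EmbedderTest and the TestChunk implementation added later in this commit:

// Sketch: with 5 chunks and a batch size of 2, embedAll is invoked 3 times (2 + 2 + 1 chunks).
EmbeddingModel model = new AllMiniLmL6V2EmbeddingModel();
EmbeddingGenerator generator = new EmbeddingGenerator(model, 2);

List<Chunk> chunks = new ArrayList<>();
for (int i = 0; i < 5; i++) {
    chunks.add(new TestEmbeddingModel.TestChunk("chunk text " + i));
}
generator.addEmbeddings(chunks);
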
src/main/java/com/marklogic/spark/writer/splitter/SplitterDocumentProcessorFactory.java

Lines changed: 1 addition & 9 deletions
@@ -10,9 +10,7 @@
 import com.marklogic.spark.writer.DocumentProcessor;
 import com.marklogic.spark.writer.dom.XPathNamespaceContext;
 import com.marklogic.spark.writer.embedding.EmbedderDocumentProcessorFactory;
-import com.marklogic.spark.writer.embedding.EmbeddingGenerator;
 import dev.langchain4j.data.document.DocumentSplitter;
-import dev.langchain4j.model.embedding.EmbeddingModel;
 
 import java.util.Arrays;
 import java.util.Optional;
@@ -85,12 +83,6 @@ private static ChunkAssembler makeChunkAssembler(ContextSupport context) {
             metadata.getPermissions().addFromDelimitedString(value);
         }
 
-        EmbeddingGenerator embeddingGenerator = null;
-        Optional<EmbeddingModel> embeddingModel = EmbedderDocumentProcessorFactory.makeEmbeddingModel(context);
-        if (embeddingModel.isPresent()) {
-            embeddingGenerator = new EmbeddingGenerator(embeddingModel.get());
-        }
-
         return new DefaultChunkAssembler(new ChunkConfig.Builder()
             .withMetadata(metadata)
             .withMaxChunks(context.getIntOption(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 0, 0))
@@ -100,7 +92,7 @@ private static ChunkAssembler makeChunkAssembler(ContextSupport context) {
             .withUriSuffix(context.getStringOption(Options.WRITE_SPLITTER_SIDECAR_URI_SUFFIX))
             .withXmlNamespace(context.getStringOption(Options.WRITE_SPLITTER_SIDECAR_XML_NAMESPACE))
             .build(),
-            embeddingGenerator
+            EmbedderDocumentProcessorFactory.makeEmbeddingGenerator(context)
         );
    }
 
src/main/resources/marklogic-spark-messages.properties

Lines changed: 2 additions & 0 deletions
@@ -20,3 +20,5 @@ spark.marklogic.write.splitter.maxChunkSize=
 spark.marklogic.write.splitter.maxOverlapSize=
 spark.marklogic.write.embedder.chunks.jsonPointer=
 spark.marklogic.write.embedder.chunks.xpath=
+spark.marklogic.write.embedder.batchSize=
+

src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java

Lines changed: 39 additions & 0 deletions
@@ -206,6 +206,45 @@ void chunksIsAnObjectInsteadOfAnArray() {
         assertEquals(JsonNodeType.ARRAY, doc.get("embedding").getNodeType());
     }
 
+    @Test
+    void testBatchSize() {
+        TestEmbeddingModel.batchCounter = 0;
+
+        readDocument("/marklogic-docs/java-client-intro.json")
+            .write().format(CONNECTOR_IDENTIFIER)
+            .option(Options.CLIENT_URI, makeClientUri())
+            .option(Options.WRITE_SPLITTER_JSON_POINTERS, "/text")
+            .option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
+            .option(Options.WRITE_URI_TEMPLATE, "/split-test.json")
+            .option(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, 500)
+            .option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, "com.marklogic.spark.writer.embedding.TestEmbeddingModel")
+            .option(Options.WRITE_EMBEDDER_BATCH_SIZE, 2)
+            .mode(SaveMode.Append)
+            .save();
+
+        JsonNode doc = readJsonDocument("/split-test.json");
+        assertEquals(4, doc.get("chunks").size());
+
+        assertEquals(2, TestEmbeddingModel.batchCounter, "Expecting 2 batches to be sent to the test " +
+            "embedding model, given the batch size of 2 and 4 chunks being created.");
+    }
+
+    @Test
+    void invalidBatchSize() {
+        DataFrameWriter writer = readDocument("/marklogic-docs/java-client-intro.json")
+            .write().format(CONNECTOR_IDENTIFIER)
+            .option(Options.CLIENT_URI, makeClientUri())
+            .option(Options.WRITE_SPLITTER_JSON_POINTERS, "/text")
+            .option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
+            .option(Options.WRITE_URI_TEMPLATE, "/split-test.json")
+            .option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, "com.marklogic.spark.writer.embedding.TestEmbeddingModel")
+            .option(Options.WRITE_EMBEDDER_BATCH_SIZE, "abc")
+            .mode(SaveMode.Append);
+
+        ConnectorException ex = assertThrowsConnectorException(() -> writer.save());
+        assertEquals("The value of 'spark.marklogic.write.embedder.batchSize' must be numeric.", ex.getMessage());
+    }
+
     private Dataset<Row> readDocument(String uri) {
         return newSparkSession().read().format(CONNECTOR_IDENTIFIER)
             .option(Options.CLIENT_URI, makeClientUri())

src/test/java/com/marklogic/spark/writer/embedding/EmbedderTest.java

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ void customizedPaths() {
             .withTextPointer("/wrapper/custom-text")
             .withEmbeddingArrayName("custom-embedding")
             .build(),
-            new AllMiniLmL6V2EmbeddingModel()
+            new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel())
         );
 
         DocumentWriteOperation output = embedder.apply(new DocumentWriteOperationImpl("a.json", null, new JacksonHandle(doc))).next();

src/test/java/com/marklogic/spark/writer/embedding/EmbeddingGeneratorTest.java

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+/*
+ * Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
+ */
+package com.marklogic.spark.writer.embedding;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class EmbeddingGeneratorTest {
+
+    @Test
+    void test() {
+        TestEmbeddingModel embeddingModel = new TestEmbeddingModel();
+        TestEmbeddingModel.batchCounter = 0;
+
+        EmbeddingGenerator generator = new EmbeddingGenerator(embeddingModel, 2);
+
+        List<Chunk> chunks = new ArrayList<>();
+        for (int i = 0; i < 5; i++) {
+            chunks.add(new TestEmbeddingModel.TestChunk("text" + i));
+        }
+
+        generator.addEmbeddings(chunks);
+        assertEquals(3, embeddingModel.batchCounter, "3 batches should have been sent given the batch size of 2.");
+    }
+
+}

src/test/java/com/marklogic/spark/writer/embedding/TestEmbeddingModel.java

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+/*
+ * Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
+ */
+package com.marklogic.spark.writer.embedding;
+
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.output.Response;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/**
+ * Used for testing the embedder batch size feature.
+ */
+class TestEmbeddingModel implements EmbeddingModel, Function<Map<String, String>, EmbeddingModel> {
+
+    static int batchCounter;
+
+    @Override
+    public EmbeddingModel apply(Map<String, String> options) {
+        return this;
+    }
+
+    @Override
+    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
+        batchCounter++;
+        return Response.from(Arrays.asList(new Embedding(new float[]{1})));
+    }
+
+    static class TestChunk implements Chunk {
+
+        private final String text;
+
+        TestChunk(String text) {
+            this.text = text;
+        }
+
+        @Override
+        public String getDocumentUri() {
+            return "/doesnt/matter.json";
+        }
+
+        @Override
+        public String getEmbeddingText() {
+            return text;
+        }
+
+        @Override
+        public void addEmbedding(Embedding embedding) {
+            // Don't need to do this for the purposes of our test.
+        }
+    }
+}
