
Commit 9763a07

Merge pull request #354 from marklogic/feature/document-batch-size-new

MLE-17715 Embedder batch size can now handle chunks across documents

2 parents: 9f647bc + a7747c0

22 files changed, +378 -133 lines

Jenkinsfile

Lines changed: 1 addition & 0 deletions

@@ -117,6 +117,7 @@ pipeline{
         sh label:'mlcleanup', script: '''#!/bin/bash
             cd marklogic-spark-connector
             docker-compose down -v || true
+            sudo /usr/local/sbin/mladmin delete $WORKSPACE/marklogic-spark-connector/docker/caddy/
             sudo /usr/local/sbin/mladmin delete $WORKSPACE/marklogic-spark-connector/docker/marklogic/logs/
         '''
     }

docker-compose.yaml

Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@ services:
   caddy-load-balancer:
     image: caddy:2-alpine
     volumes:
-      - ./caddy/data:/data
-      - ./caddy/config/Caddyfile:/etc/caddy/Caddyfile
+      # Not mapping the Caddy data directory, as that causes issues for Jenkins.
+      - ./docker/caddy/config/Caddyfile:/etc/caddy/Caddyfile
     ports:
       # Expand this range as needed. See Caddyfile for which ports are used for reverse proxies.
       - "8115:8115"

File renamed without changes.

src/main/java/com/marklogic/spark/Util.java

Lines changed: 6 additions & 0 deletions

@@ -18,6 +18,12 @@ public interface Util {
      */
     Logger MAIN_LOGGER = LoggerFactory.getLogger("com.marklogic.spark");

+    /**
+     * Intended for log messages pertaining to the embedder feature. Uses a separate logger so that it can be enabled
+     * at the info/debug level without enabling any other log messages.
+     */
+    Logger EMBEDDER_LOGGER = LoggerFactory.getLogger("com.marklogic.spark.embedder");
+
     static boolean hasOption(Map<String, String> properties, String... options) {
         return Stream.of(options)
             .anyMatch(option -> properties.get(option) != null && properties.get(option).trim().length() > 0);
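
Because EMBEDDER_LOGGER has its own logger name, its level can be raised without turning on the rest of the connector's logging. A minimal sketch of doing that programmatically, assuming Logback is the SLF4J binding in use (an assumption; any SLF4J backend's own configuration file would work just as well):

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import org.slf4j.LoggerFactory;

public class EnableEmbedderLogging {
    public static void main(String[] args) {
        // Raise only the embedder logger to DEBUG; "com.marklogic.spark" keeps its configured level.
        Logger embedderLogger = (Logger) LoggerFactory.getLogger("com.marklogic.spark.embedder");
        embedderLogger.setLevel(Level.DEBUG);
    }
}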

src/main/java/com/marklogic/spark/writer/DocumentProcessorFactory.java

Lines changed: 10 additions & 5 deletions

@@ -5,19 +5,24 @@

 import com.marklogic.spark.ContextSupport;
 import com.marklogic.spark.writer.embedding.EmbedderDocumentProcessorFactory;
+import com.marklogic.spark.writer.splitter.SplitterDocumentProcessor;
 import com.marklogic.spark.writer.splitter.SplitterDocumentProcessorFactory;

 import java.util.Optional;

 abstract class DocumentProcessorFactory {

     static DocumentProcessor buildDocumentProcessor(ContextSupport context) {
-        Optional<DocumentProcessor> splitter = SplitterDocumentProcessorFactory.makeSplitter(context);
-        if (splitter.isPresent()) {
-            return splitter.get();
+        Optional<SplitterDocumentProcessor> splitter = SplitterDocumentProcessorFactory.makeSplitter(context);
+
+        Optional<DocumentProcessor> embedder = EmbedderDocumentProcessorFactory.makeEmbedder(
+            context, splitter.isPresent() ? splitter.get() : null
+        );
+
+        if (embedder.isPresent()) {
+            return embedder.get();
         }
-        Optional<DocumentProcessor> embedder = EmbedderDocumentProcessorFactory.makeEmbedder(context);
-        return embedder.isPresent() ? embedder.get() : null;
+        return splitter.isPresent() ? splitter.get() : null;
     }

     private DocumentProcessorFactory() {
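
The reordered logic gives the embedder precedence because it may wrap the splitter; only when no embedder is configured does the splitter run alone. A compact illustration of that precedence with hypothetical stand-ins (UnaryOperator<String> in place of the connector's processors):

import java.util.Optional;
import java.util.function.UnaryOperator;

public class ProcessorPrecedenceSketch {

    public static void main(String[] args) {
        // A splitter may or may not be configured; the embedder, when configured,
        // wraps whatever splitter exists so that splitting happens first.
        Optional<UnaryOperator<String>> splitter = Optional.of(s -> s + " -> split");
        Optional<UnaryOperator<String>> embedder = makeEmbedder(splitter.orElse(null));

        // Prefer the embedder (which already contains the splitter), else fall back to the splitter.
        UnaryOperator<String> processor = embedder.or(() -> splitter).orElse(null);
        System.out.println(processor.apply("doc")); // prints: doc -> split -> embed
    }

    static Optional<UnaryOperator<String>> makeEmbedder(UnaryOperator<String> splitter) {
        UnaryOperator<String> split = splitter != null ? splitter : s -> s;
        return Optional.of(s -> split.apply(s) + " -> embed");
    }
}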

src/main/java/com/marklogic/spark/writer/WriteBatcherDataWriter.java

Lines changed: 15 additions & 0 deletions

@@ -36,6 +36,7 @@
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;

 /**
  * Uses the Java Client's WriteBatcher to handle writing rows as documents to MarkLogic.
@@ -108,6 +109,8 @@ public WriterCommitMessage commit() {
         // in a document. Those are retrieved here.
         buildAndWriteDocuments(rowConverter.getRemainingDocumentInputs());

+        flushDocumentProcessor();
+
         this.writeBatcher.flushAndWait();

         throwWriteFailureIfExists();
@@ -179,6 +182,18 @@ private Set<String> getGraphNames() {
             null;
     }

+    /**
+     * A document processor can implement Supplier so that it can batch up documents to be written and then return
+     * any pending documents during the commit operation. This allows for the embedder processor to batch calls to the
+     * embedding model.
+     */
+    private void flushDocumentProcessor() {
+        if (this.documentProcessor instanceof Supplier) {
+            Iterator<DocumentWriteOperation> remainingDocuments = ((Supplier<Iterator<DocumentWriteOperation>>) this.documentProcessor).get();
+            remainingDocuments.forEachRemaining(this::writeDocument);
+        }
+    }
+
     private void addBatchListeners(WriteBatcher writeBatcher) {
         writeBatcher.onBatchSuccess(batch -> this.successItemCount.getAndAdd(batch.getItems().length));
         if (writeContext.isAbortOnFailure()) {
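
The Supplier check in flushDocumentProcessor is the entire flush contract between the writer and a batching processor. Below is a self-contained sketch of that contract using hypothetical stand-in types (Doc, BatchingProcessor), not the connector's classes:

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;

public class FlushContractSketch {

    // Stand-in for DocumentWriteOperation.
    record Doc(String uri) {}

    // A processor that batches documents: apply() may hold documents back; get() releases the rest.
    static class BatchingProcessor implements Function<Doc, Iterator<Doc>>, Supplier<Iterator<Doc>> {
        private final List<Doc> pending = new ArrayList<>();
        private final int batchSize;

        BatchingProcessor(int batchSize) {
            this.batchSize = batchSize;
        }

        @Override
        public Iterator<Doc> apply(Doc doc) {
            pending.add(doc);
            if (pending.size() >= batchSize) {
                List<Doc> batch = new ArrayList<>(pending);
                pending.clear();
                return batch.iterator();
            }
            return List.<Doc>of().iterator();
        }

        @Override
        public Iterator<Doc> get() {
            // Called at commit time; releases documents that never filled a batch.
            List<Doc> rest = new ArrayList<>(pending);
            pending.clear();
            return rest.iterator();
        }
    }

    public static void main(String[] args) {
        BatchingProcessor processor = new BatchingProcessor(2);
        for (String uri : new String[]{"/a.json", "/b.json", "/c.json"}) {
            processor.apply(new Doc(uri)).forEachRemaining(d -> System.out.println("write " + d.uri()));
        }
        // Commit: flush whatever the processor is still holding, mirroring flushDocumentProcessor().
        processor.get().forEachRemaining(d -> System.out.println("write at commit " + d.uri()));
    }
}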

src/main/java/com/marklogic/spark/writer/embedding/ChunkSelector.java

Lines changed: 0 additions & 20 deletions

@@ -5,8 +5,6 @@

 import com.marklogic.client.document.DocumentWriteOperation;

-import java.util.List;
-
 /**
  * Abstracts how chunks are selected from a JSON or XML document.
  */
@@ -19,22 +17,4 @@ public interface ChunkSelector {
      */
     DocumentAndChunks selectChunks(DocumentWriteOperation sourceDocument);

-    class DocumentAndChunks {
-
-        private final DocumentWriteOperation documentToWrite;
-        private final List<Chunk> chunks;
-
-        DocumentAndChunks(DocumentWriteOperation documentToWrite, List<Chunk> chunks) {
-            this.documentToWrite = documentToWrite;
-            this.chunks = chunks;
-        }
-
-        public DocumentWriteOperation getDocumentToWrite() {
-            return documentToWrite;
-        }
-
-        public List<Chunk> getChunks() {
-            return chunks;
-        }
-    }
 }

src/main/java/com/marklogic/spark/writer/embedding/DOMChunkSelector.java

Lines changed: 5 additions & 3 deletions

@@ -20,18 +20,20 @@
 import java.util.ArrayList;
 import java.util.List;

-public class DOMChunkSelector implements ChunkSelector {
+class DOMChunkSelector implements ChunkSelector {

     private final XPathFactory xpathFactory;
     private final XPathExpression chunksExpression;
     private final XmlChunkConfig xmlChunkConfig;
     private final DOMHelper domHelper;

-    public DOMChunkSelector(String chunksExpression, XmlChunkConfig xmlChunkConfig) {
+    DOMChunkSelector(String chunksExpression, XmlChunkConfig xmlChunkConfig) {
         this.xpathFactory = XPathFactory.newInstance();
         this.xmlChunkConfig = xmlChunkConfig;
         this.domHelper = new DOMHelper(xmlChunkConfig.getNamespaceContext());
-        this.chunksExpression = domHelper.compileXPath(chunksExpression, "selecting chunks");
+
+        String chunksXPath = chunksExpression != null ? chunksExpression : "/node()/chunks";
+        this.chunksExpression = domHelper.compileXPath(chunksXPath, "selecting chunks");
     }

     @Override
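
The constructor now falls back to the default XPath /node()/chunks when no expression is configured. A standalone check of what that default matches, using only JDK APIs; the sample document shape here is assumed for illustration:

import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpression;
import javax.xml.xpath.XPathFactory;
import java.io.ByteArrayInputStream;
import org.w3c.dom.Document;
import org.w3c.dom.NodeList;

public class DefaultChunksXPathCheck {
    public static void main(String[] args) throws Exception {
        String xml = "<root><chunks><chunk><text>hello</text></chunk>"
            + "<chunk><text>world</text></chunk></chunks></root>";
        Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
            .parse(new ByteArrayInputStream(xml.getBytes()));

        // Same default expression the constructor falls back to when chunksExpression is null.
        XPathExpression expr = XPathFactory.newInstance().newXPath().compile("/node()/chunks");
        NodeList chunksElements = (NodeList) expr.evaluate(doc, XPathConstants.NODESET);
        System.out.println("Matched " + chunksElements.getLength() + " chunks element(s)");
    }
}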
src/main/java/com/marklogic/spark/writer/embedding/DocumentAndChunks.java

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+/*
+ * Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
+ */
+package com.marklogic.spark.writer.embedding;
+
+import com.marklogic.client.document.DocumentWriteOperation;
+import com.marklogic.client.io.marker.AbstractWriteHandle;
+import com.marklogic.client.io.marker.DocumentMetadataWriteHandle;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.List;
+
+/**
+ * Encapsulates a document to be written to MarkLogic along with an optional list of chunks that have been extracted
+ * from it. Capturing the list of chunks is useful when a user wishes to use both the splitter and embedder. In that
+ * scenario, the embedder can reuse the list of chunks produced by the splitter without having to find the chunks
+ * itself.
+ */
+public class DocumentAndChunks implements DocumentWriteOperation {
+
+    private final DocumentWriteOperation documentToWrite;
+    private final List<Chunk> chunks;
+
+    public DocumentAndChunks(DocumentWriteOperation documentToWrite, List<Chunk> chunks) {
+        this.documentToWrite = documentToWrite;
+        this.chunks = chunks;
+    }
+
+    public DocumentWriteOperation getDocumentToWrite() {
+        return documentToWrite;
+    }
+
+    public List<Chunk> getChunks() {
+        return chunks;
+    }
+
+    public boolean hasChunks() {
+        return chunks != null && !chunks.isEmpty();
+    }
+
+    @Override
+    public OperationType getOperationType() {
+        return OperationType.DOCUMENT_WRITE;
+    }
+
+    @Override
+    public String getUri() {
+        return documentToWrite.getUri();
+    }
+
+    @Override
+    public DocumentMetadataWriteHandle getMetadata() {
+        return documentToWrite.getMetadata();
+    }
+
+    @Override
+    public AbstractWriteHandle getContent() {
+        return documentToWrite.getContent();
+    }
+
+    @Override
+    public String getTemporalDocumentURI() {
+        return documentToWrite.getTemporalDocumentURI();
+    }
+
+    @Override
+    public int compareTo(@NotNull DocumentWriteOperation o) {
+        return documentToWrite.compareTo(o);
+    }
+}
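
Because DocumentAndChunks delegates every DocumentWriteOperation method to the wrapped operation, downstream code can treat it as an ordinary write operation and recover the chunks with an instanceof check. A small illustration of that decorator pattern with hypothetical stand-in types, not the Java Client interfaces:

import java.util.List;

public class DecoratorCheckSketch {

    // Hypothetical stand-ins for DocumentWriteOperation and DocumentAndChunks.
    interface WriteOp { String uri(); }

    record PlainDoc(String uri) implements WriteOp {}

    record DocWithChunks(WriteOp delegate, List<String> chunks) implements WriteOp {
        public String uri() { return delegate.uri(); } // delegate like DocumentAndChunks does
        boolean hasChunks() { return chunks != null && !chunks.isEmpty(); }
    }

    public static void main(String[] args) {
        List<WriteOp> ops = List.of(
            new PlainDoc("/plain.json"),
            new DocWithChunks(new PlainDoc("/split.json"), List.of("chunk one", "chunk two")));
        for (WriteOp op : ops) {
            // Mirrors the instanceof check the embedder applies to documents coming out of the splitter.
            if (op instanceof DocWithChunks doc && doc.hasChunks()) {
                System.out.println(op.uri() + " has " + doc.chunks().size() + " chunk(s) to embed");
            } else {
                System.out.println(op.uri() + " has no chunks; write as-is");
            }
        }
    }
}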

src/main/java/com/marklogic/spark/writer/embedding/EmbedderDocumentProcessor.java

Lines changed: 86 additions & 6 deletions

@@ -4,31 +4,111 @@
 package com.marklogic.spark.writer.embedding;

 import com.marklogic.client.document.DocumentWriteOperation;
+import com.marklogic.spark.Util;
 import com.marklogic.spark.writer.DocumentProcessor;
+import com.marklogic.spark.writer.splitter.SplitterDocumentProcessor;

+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
+import java.util.function.Supplier;
 import java.util.stream.Stream;

 /**
  * Supports a use case where a document already has chunks in it, which must be selected via a {@code ChunkSelector}.
  * The {@code EmbeddingModel} is then used to generate and add an embedding to each chunk in a given document.
  */
-class EmbedderDocumentProcessor implements DocumentProcessor {
+class EmbedderDocumentProcessor implements DocumentProcessor, Supplier<Iterator<DocumentWriteOperation>> {

     private final ChunkSelector chunkSelector;
     private final EmbeddingGenerator embeddingGenerator;
+    private final SplitterDocumentProcessor splitterDocumentProcessor;

-    EmbedderDocumentProcessor(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator) {
+    private List<DocumentWriteOperation> pendingSourceDocuments = new ArrayList<>();
+
+    EmbedderDocumentProcessor(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator, SplitterDocumentProcessor splitterDocumentProcessor) {
         this.chunkSelector = chunkSelector;
         this.embeddingGenerator = embeddingGenerator;
+        this.splitterDocumentProcessor = splitterDocumentProcessor;
     }

+    /**
+     * Splits the source document first when a splitter is present; otherwise, selects existing chunks from the
+     * document. Either way, documents are held as pending until the embedding generator reports that it has
+     * generated embeddings for their chunks.
+     */
     @Override
     public Iterator<DocumentWriteOperation> apply(DocumentWriteOperation sourceDocument) {
-        ChunkSelector.DocumentAndChunks documentAndChunks = chunkSelector.selectChunks(sourceDocument);
-        if (documentAndChunks.getChunks() != null && !documentAndChunks.getChunks().isEmpty()) {
-            embeddingGenerator.addEmbeddings(documentAndChunks.getChunks());
+        if (splitterDocumentProcessor != null) {
+            return splitAndAddEmbeddings(sourceDocument);
+        }
+
+        DocumentAndChunks documentAndChunks = chunkSelector.selectChunks(sourceDocument);
+        return documentAndChunks.hasChunks() ?
+            addEmbeddingsToExistingChunks(documentAndChunks) :
+            // If no chunks are found, embeddings can't be added, so just return the source document.
+            Stream.of(documentAndChunks.getDocumentToWrite()).iterator();
+    }
+
+    @Override
+    public Iterator<DocumentWriteOperation> get() {
+        // Return any pending source documents - i.e. those with chunks that didn't add up to the embedding generator's
+        // batch size, and thus embeddings haven't been added.
+        if (pendingSourceDocuments != null && !pendingSourceDocuments.isEmpty()) {
+            if (Util.EMBEDDER_LOGGER.isInfoEnabled()) {
+                Util.EMBEDDER_LOGGER.info("Pending source document count: {}; generating embeddings for each document.",
+                    pendingSourceDocuments.size());
+            }
+            embeddingGenerator.generateEmbeddingsForPendingChunks();
+            return pendingSourceDocuments.iterator();
+        }
+        return Stream.<DocumentWriteOperation>empty().iterator();
+    }
+
+    private Iterator<DocumentWriteOperation> splitAndAddEmbeddings(DocumentWriteOperation sourceDocument) {
+        Iterator<DocumentWriteOperation> splitDocuments = splitterDocumentProcessor.apply(sourceDocument);
+
+        // Track the list of documents to return. A document won't be returned immediately if it has chunks but the
+        // embedding generator doesn't receive enough chunks to meet its batch size threshold.
+        List<DocumentWriteOperation> documentsToReturn = new ArrayList<>();
+
+        splitDocuments.forEachRemaining(splitDoc -> {
+            boolean hasChunks = splitDoc instanceof DocumentAndChunks && ((DocumentAndChunks) splitDoc).hasChunks();
+            if (hasChunks) {
+                DocumentAndChunks documentAndChunks = (DocumentAndChunks) splitDoc;
+                pendingSourceDocuments.add(documentAndChunks);
+                boolean embeddingsWereGenerated = embeddingGenerator.addEmbeddings(documentAndChunks);
+                // If the embedding generator received enough chunks to exceed its batch size, then all the pending
+                // documents can be added to the list of documents to return, as we know those documents will have had
+                // embeddings added to them.
+                if (embeddingsWereGenerated) {
+                    documentsToReturn.addAll(pendingSourceDocuments);
+                    pendingSourceDocuments.clear();
+                }
+            } else {
+                // If the document doesn't have any chunks, it can be returned immediately.
+                documentsToReturn.add(splitDoc);
+            }
+        });
+
+        return documentsToReturn.iterator();
+    }
+
+    /**
+     * For existing chunks - add the document to the pending list, then add embeddings. If embeddings were generated,
+     * return an iterator over all the pending documents, which now have embeddings.
+     */
+    private Iterator<DocumentWriteOperation> addEmbeddingsToExistingChunks(DocumentAndChunks documentAndChunks) {
+        pendingSourceDocuments.add(documentAndChunks);
+        boolean embeddingsWereGenerated = embeddingGenerator.addEmbeddings(documentAndChunks);
+        if (embeddingsWereGenerated) {
+            List<DocumentWriteOperation> documentsWithEmbeddings = new ArrayList<>(pendingSourceDocuments);
+            pendingSourceDocuments.clear();
+            return documentsWithEmbeddings.iterator();
+        } else {
+            return Stream.<DocumentWriteOperation>empty().iterator();
         }
-        return Stream.of(documentAndChunks.getDocumentToWrite()).iterator();
     }
 }
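
The cross-document behavior named in the commit message is easiest to see with concrete numbers. A self-contained sketch with toy stand-ins (Doc, ToyEmbeddingGenerator) that are not the connector's real classes, assuming a batch size of three chunks:

import java.util.ArrayList;
import java.util.List;

public class CrossDocumentBatchingSketch {

    record Doc(String uri, int chunkCount) {}

    // Toy stand-in for EmbeddingGenerator: buffers chunks and reports true once a batch is flushed.
    static class ToyEmbeddingGenerator {
        private final int batchSize;
        private int bufferedChunks;

        ToyEmbeddingGenerator(int batchSize) {
            this.batchSize = batchSize;
        }

        boolean addEmbeddings(Doc doc) {
            bufferedChunks += doc.chunkCount();
            if (bufferedChunks >= batchSize) {
                bufferedChunks = 0; // pretend the embedding model was called here
                return true;
            }
            return false;
        }
    }

    public static void main(String[] args) {
        ToyEmbeddingGenerator generator = new ToyEmbeddingGenerator(3);
        List<Doc> pending = new ArrayList<>();

        // Each document has a single chunk, so no document fills the batch on its own.
        for (Doc doc : List.of(new Doc("/a.json", 1), new Doc("/b.json", 1), new Doc("/c.json", 1))) {
            pending.add(doc);
            if (generator.addEmbeddings(doc)) {
                // Batch filled across documents: every pending doc now has embeddings and can be written.
                System.out.println("Writing " + pending);
                pending.clear();
            }
        }
        // Anything still pending when the Spark task commits is flushed via the processor's get().
        System.out.println("Still pending at commit: " + pending);
    }
}

With single-chunk documents, nothing is written until the third document fills the batch; documents still pending at commit time are released through get(), which is exactly what WriteBatcherDataWriter.flushDocumentProcessor invokes above.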
