Skip to content

Commit 3bba4e5

Browse files
authored
Merge pull request #371 from marklogic/feature/embed-with-custom-namespace
Refactored EmbeddingAdder
2 parents b593c8c + 3242420 commit 3bba4e5

File tree

6 files changed

+76
-31
lines changed

6 files changed

+76
-31
lines changed

marklogic-langchain4j/src/main/java/com/marklogic/langchain4j/embedding/EmbeddingAdder.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,36 @@ public class EmbeddingAdder implements Function<DocumentWriteOperation, Iterator
2626

2727
private List<DocumentWriteOperation> pendingSourceDocuments = new ArrayList<>();
2828

29-
public EmbeddingAdder(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator, DocumentTextSplitter documentTextSplitter) {
30-
this.chunkSelector = chunkSelector;
31-
this.embeddingGenerator = embeddingGenerator;
29+
/**
30+
* Use this when a user has configured a splitter, as the splitter will return {@code DocumentAndChunks} instances
31+
* that avoid the need for using a {@code ChunkSelector} to find chunks.
32+
*
33+
* @param documentTextSplitter
34+
* @param embeddingGenerator
35+
*/
36+
public EmbeddingAdder(DocumentTextSplitter documentTextSplitter, EmbeddingGenerator embeddingGenerator) {
3237
this.documentTextSplitter = documentTextSplitter;
38+
this.embeddingGenerator = embeddingGenerator;
39+
this.chunkSelector = null;
3340
}
3441

3542
/**
36-
* I think we can hold onto documents here? addEmbeddings could return true/false if it actually sends anything.
43+
* Use this constructor when the user has not configured a splitter, as the {@code ChunkSelector} is needed to find
44+
* chunks in each document.
3745
*
38-
* @param sourceDocument the function argument
39-
* @return
46+
* @param chunkSelector
47+
* @param embeddingGenerator
4048
*/
49+
public EmbeddingAdder(ChunkSelector chunkSelector, EmbeddingGenerator embeddingGenerator) {
50+
this.chunkSelector = chunkSelector;
51+
this.embeddingGenerator = embeddingGenerator;
52+
this.documentTextSplitter = null;
53+
}
54+
4155
@Override
4256
public Iterator<DocumentWriteOperation> apply(DocumentWriteOperation sourceDocument) {
57+
// If the user configured a splitter, then follow a path where the source document is split, which will produce
58+
// DocumentAndChunks instances. Which means the ChunkSelector isn't needed.
4359
if (documentTextSplitter != null) {
4460
return splitAndAddEmbeddings(sourceDocument);
4561
}

marklogic-spark-api/src/main/java/com/marklogic/spark/Context.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
public abstract class Context implements Serializable {
1111

12-
protected final Map<String, String> properties;
12+
private final Map<String, String> properties;
1313

1414
protected Context(Map<String, String> properties) {
1515
this.properties = properties;
@@ -49,7 +49,7 @@ public final long getNumericOption(String optionName, long defaultValue, long mi
4949
public final boolean getBooleanOption(String option, boolean defaultValue) {
5050
return hasOption(option) ? Boolean.parseBoolean(getStringOption(option)) : defaultValue;
5151
}
52-
52+
5353
public final String getOptionNameForMessage(String option) {
5454
return Util.getOptionNameForErrorMessage(option);
5555
}

marklogic-spark-connector/src/main/java/com/marklogic/spark/ContextSupport.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,18 @@ protected final Map<String, String> buildConnectionProperties() {
7979
Map<String, String> connectionProps = new HashMap<>();
8080
connectionProps.put("spark.marklogic.client.authType", "digest");
8181
connectionProps.put("spark.marklogic.client.connectionType", "gateway");
82-
connectionProps.putAll(this.properties);
82+
connectionProps.putAll(getProperties());
8383
if (optionExists(Options.CLIENT_URI)) {
84-
parseConnectionString(properties.get(Options.CLIENT_URI), connectionProps);
84+
parseConnectionString(getProperties().get(Options.CLIENT_URI), connectionProps);
8585
}
86-
if ("true".equalsIgnoreCase(properties.get(Options.CLIENT_SSL_ENABLED))) {
86+
if ("true".equalsIgnoreCase(getProperties().get(Options.CLIENT_SSL_ENABLED))) {
8787
connectionProps.put("spark.marklogic.client.sslProtocol", "default");
8888
}
8989
return connectionProps;
9090
}
9191

9292
public final boolean optionExists(String option) {
93-
String value = properties.get(option);
93+
String value = getProperties().get(option);
9494
return value != null && value.trim().length() > 0;
9595
}
9696

marklogic-spark-connector/src/test/java/com/marklogic/langchain4j/embedding/EmbedderTest.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ class EmbedderTest extends AbstractIntegrationTest {
3434
@Test
3535
void defaultPaths() {
3636
DocumentTextSplitter splitter = newJsonSplitter(500, 2, "/text");
37-
EmbeddingAdder embedder = new EmbeddingAdder(
38-
new JsonChunkSelector.Builder().build(), new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()), splitter
39-
);
37+
EmbeddingAdder embedder = new EmbeddingAdder(splitter, new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()));
4038

4139
Iterator<DocumentWriteOperation> docs = embedder.apply(readJsonDocument());
4240

@@ -66,8 +64,7 @@ void customizedPaths() {
6664
.withTextPointer("/wrapper/custom-text")
6765
.withEmbeddingArrayName("custom-embedding")
6866
.build(),
69-
new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()),
70-
null
67+
new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel())
7168
);
7269

7370
DocumentWriteOperation output = embedder.apply(new DocumentWriteOperationImpl("a.json", null, new JacksonHandle(doc))).next();
@@ -82,10 +79,7 @@ void customizedPaths() {
8279
@Test
8380
void xml() {
8481
DocumentTextSplitter splitter = newXmlSplitter(500, 2, "/node()/text");
85-
EmbeddingAdder embedder = new EmbeddingAdder(
86-
new DOMChunkSelector(null, new XmlChunkConfig()),
87-
new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()), splitter
88-
);
82+
EmbeddingAdder embedder = new EmbeddingAdder(splitter, new EmbeddingGenerator(new AllMiniLmL6V2EmbeddingModel()));
8983

9084
Iterator<DocumentWriteOperation> docs = embedder.apply(readXmlDocument());
9185

marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToXmlTest.java

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,36 @@ void sidecarWithNamespace() {
8585
.mode(SaveMode.Append)
8686
.save();
8787

88-
XmlNode doc = readXmlDocument("/split-test.xml-chunks-1.xml", Namespace.getNamespace("ex", "org:example"));
89-
doc.assertElementCount("/ex:sidecar/ex:chunks/ex:chunk", 4);
90-
for (XmlNode chunk : doc.getXmlNodes("/ex:sidecar/ex:chunks/ex:chunk")) {
91-
chunk.assertElementExists("/ex:chunk/ex:text");
92-
chunk.assertElementExists("For now, the embedding still defaults to the empty namespace. We may change " +
93-
"this soon to be a MarkLogic-specific namespace to better distinguish it from the users " +
94-
"content.", "/ex:chunk/embedding");
95-
}
88+
verifyChunksInNamespacedSidecar();
89+
verifyEachChunkIsReturnedByAVectorQuery("namespaced_xml_chunks");
90+
}
91+
92+
/**
93+
* This test verifies that when the source document does not have a namespace but the sidecar document does,
94+
* the chunks still get embeddings because the connector doesn't need to use a ChunkSelector. That is due to the
95+
* connector knowing that the splitter will return instances of DocumentAndChunks, which means the embedder can
96+
* access the chunks without having to find them.
97+
*/
98+
@ExtendWith(RequiresMarkLogic12.class)
99+
@Test
100+
void sidecarWithCustomNamespace() {
101+
readDocument("/marklogic-docs/java-client-intro.xml")
102+
.write().format(CONNECTOR_IDENTIFIER)
103+
.option(Options.CLIENT_URI, makeClientUri())
104+
.option(Options.XPATH_NAMESPACE_PREFIX + "ex", "org:example")
105+
.option(Options.WRITE_SPLITTER_XPATH, "/node()/text/text()")
106+
.option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
107+
.option(Options.WRITE_URI_TEMPLATE, "/split-test.xml")
108+
.option(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, 500)
109+
.option(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 4)
110+
.option(Options.WRITE_SPLITTER_SIDECAR_ROOT_NAME, "sidecar")
111+
.option(Options.WRITE_SPLITTER_SIDECAR_XML_NAMESPACE, "org:example")
112+
.option(Options.WRITE_SPLITTER_SIDECAR_COLLECTIONS, "namespaced-xml-vector-chunks")
113+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, TEST_EMBEDDING_FUNCTION_CLASS)
114+
.mode(SaveMode.Append)
115+
.save();
96116

117+
verifyChunksInNamespacedSidecar();
97118
verifyEachChunkIsReturnedByAVectorQuery("namespaced_xml_chunks");
98119
}
99120

@@ -247,4 +268,15 @@ private void verifyEachChunkIsReturnedByAVectorQuery(String viewName) {
247268

248269
assertEquals(4, counter, "Each test is expected to produce 4 chunks based on the max chunk size of 500.");
249270
}
271+
272+
private void verifyChunksInNamespacedSidecar() {
273+
XmlNode doc = readXmlDocument("/split-test.xml-chunks-1.xml", Namespace.getNamespace("ex", "org:example"));
274+
doc.assertElementCount("/ex:sidecar/ex:chunks/ex:chunk", 4);
275+
for (XmlNode chunk : doc.getXmlNodes("/ex:sidecar/ex:chunks/ex:chunk")) {
276+
chunk.assertElementExists("/ex:chunk/ex:text");
277+
chunk.assertElementExists("For now, the embedding still defaults to the empty namespace. We may change " +
278+
"this soon to be a MarkLogic-specific namespace to better distinguish it from the users " +
279+
"content.", "/ex:chunk/embedding");
280+
}
281+
}
250282
}

marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ public abstract class EmbeddingAdderFactory {
2121
public static Optional<EmbeddingAdder> makeEmbedder(Context context, DocumentTextSplitter splitter) {
2222
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
2323
if (embeddingModel.isPresent()) {
24-
ChunkSelector chunkSelector = makeChunkSelector(context);
2524
EmbeddingGenerator embeddingGenerator = makeEmbeddingGenerator(context);
26-
return Optional.of(new EmbeddingAdder(chunkSelector, embeddingGenerator, splitter));
25+
if (splitter != null) {
26+
return Optional.of(new EmbeddingAdder(splitter, embeddingGenerator));
27+
}
28+
ChunkSelector chunkSelector = makeChunkSelector(context);
29+
return Optional.of(new EmbeddingAdder(chunkSelector, embeddingGenerator));
2730
}
2831
return Optional.empty();
2932
}

0 commit comments

Comments
 (0)