Skip to content

Commit b58226f

Browse files
committed
Fix for adding embeddings to sidecar docs from text docs
1 parent 5bdb785 commit b58226f

File tree

3 files changed

+91
-1
lines changed

3 files changed

+91
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.writer.embedding;
5+
6+
import com.fasterxml.jackson.databind.JsonNode;
7+
import com.fasterxml.jackson.databind.node.JsonNodeType;
8+
import com.marklogic.junit5.XmlNode;
9+
import com.marklogic.spark.AbstractIntegrationTest;
10+
import com.marklogic.spark.Options;
11+
import org.apache.spark.sql.DataFrameWriter;
12+
import org.apache.spark.sql.Row;
13+
import org.apache.spark.sql.SaveMode;
14+
import org.junit.jupiter.api.Test;
15+
16+
import java.util.List;
17+
18+
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
import static org.junit.jupiter.api.Assertions.assertTrue;
20+
21+
/**
22+
* Verifies that when split text from text documents and then adding embeddings to the sidecar docs, the user doesn't
23+
* need to specify the location of the chunks. The connector is expected to determine the location based on whether the
24+
* sidecar docs are JSON or XML.
25+
*/
26+
class AddEmbeddingsFromTextTest extends AbstractIntegrationTest {
27+
28+
private static final String TEST_EMBEDDING_FUNCTION_CLASS = "com.marklogic.spark.writer.embedding.MinilmEmbeddingModelFunction";
29+
30+
@Test
31+
void jsonSidecarDocuments() {
32+
prepareToWriteChunks()
33+
.mode(SaveMode.Append)
34+
.save();
35+
36+
List<String> uris = getUrisInCollection("text-chunks", 4);
37+
for (String uri : uris) {
38+
assertTrue(uri.endsWith(".json"));
39+
JsonNode doc = readJsonDocument(uri);
40+
assertEquals(JsonNodeType.ARRAY, doc.get("chunks").getNodeType());
41+
}
42+
}
43+
44+
@Test
45+
void xmlSidecarDocuments() {
46+
prepareToWriteChunks()
47+
.option(Options.WRITE_SPLITTER_SIDECAR_DOCUMENT_TYPE, "xml")
48+
.mode(SaveMode.Append)
49+
.save();
50+
51+
List<String> uris = getUrisInCollection("text-chunks", 4);
52+
for (String uri : uris) {
53+
assertTrue(uri.endsWith(".xml"));
54+
XmlNode doc = readXmlDocument(uri);
55+
doc.assertElementCount("/node()/chunks/chunk", 1);
56+
}
57+
}
58+
59+
private DataFrameWriter<Row> prepareToWriteChunks() {
60+
return newSparkSession().read().format(CONNECTOR_IDENTIFIER)
61+
.option(Options.CLIENT_URI, makeClientUri())
62+
.option(Options.READ_DOCUMENTS_URIS, "/marklogic-docs/java-client-intro.txt")
63+
.load()
64+
.write().format(CONNECTOR_IDENTIFIER)
65+
.option(Options.CLIENT_URI, makeClientUri())
66+
.option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
67+
.option(Options.WRITE_URI_PREFIX, "/test")
68+
.option(Options.WRITE_SPLITTER_TEXT, true)
69+
.option(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, 500)
70+
.option(Options.WRITE_SPLITTER_SIDECAR_MAX_CHUNKS, 1)
71+
.option(Options.WRITE_SPLITTER_SIDECAR_COLLECTIONS, "text-chunks")
72+
.option(Options.WRITE_EMBEDDER_MODEL_FUNCTION_CLASS_NAME, TEST_EMBEDDING_FUNCTION_CLASS);
73+
}
74+
75+
76+
}

marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/TestEmbeddingModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ public EmbeddingModel apply(Map<String, String> options) {
3232
return this;
3333
}
3434

35+
@Override
public int dimension() {
    // Test stub: returns 0 because this test model's dimension is not consulted by the
    // code under test — presumably only embedAll matters here. TODO confirm no caller
    // relies on a non-zero dimension.
    return 0;
}
39+
3540
@Override
3641
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
3742
batchCounter++;

marklogic-spark-langchain4j/src/main/java/com/marklogic/spark/langchain4j/EmbeddingAdderFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.marklogic.spark.ConnectorException;
99
import com.marklogic.spark.Context;
1010
import com.marklogic.spark.Options;
11+
import com.marklogic.spark.Util;
1112
import dev.langchain4j.model.embedding.EmbeddingModel;
1213

1314
import java.util.HashMap;
@@ -31,7 +32,11 @@ public static EmbeddingGenerator makeEmbeddingGenerator(Context context) {
3132
Optional<EmbeddingModel> embeddingModel = makeEmbeddingModel(context);
3233
if (embeddingModel.isPresent()) {
3334
int batchSize = context.getIntOption(Options.WRITE_EMBEDDER_BATCH_SIZE, 1, 1);
34-
return new EmbeddingGenerator(embeddingModel.get(), batchSize);
35+
EmbeddingModel model = embeddingModel.get();
36+
if (Util.MAIN_LOGGER.isInfoEnabled()) {
37+
Util.MAIN_LOGGER.info("Using embedding model with dimension: {}", model.dimension());
38+
}
39+
return new EmbeddingGenerator(model, batchSize);
3540
}
3641
return null;
3742
}
@@ -55,6 +60,10 @@ private static ChunkSelector makeChunkSelector(Context context) {
5560
return makeJsonChunkSelector(context);
5661
} else if (context.hasOption(Options.WRITE_EMBEDDER_CHUNKS_XPATH)) {
5762
return makeXmlChunkSelector(context);
63+
} else if (context.hasOption(Options.WRITE_SPLITTER_TEXT)) {
64+
return "xml".equalsIgnoreCase(context.getStringOption(Options.WRITE_SPLITTER_SIDECAR_DOCUMENT_TYPE)) ?
65+
makeXmlChunkSelector(context) :
66+
makeJsonChunkSelector(context);
5867
}
5968
throw new ConnectorException(String.format("To generate embeddings on documents, you must specify either " +
6069
"%s or %s to define the location of chunks in documents.",

0 commit comments

Comments
 (0)