Skip to content

Commit dd897de

Browse files
authored
Merge pull request #347 from marklogic/feature/embedder-xml-dom
Verifying that DOM is a better approach for selecting chunks
2 parents 2190a6f + 72207da commit dd897de

File tree

4 files changed

+214
-1
lines changed

4 files changed

+214
-1
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.writer.embedding;
5+
6+
import com.marklogic.spark.ConnectorException;
7+
import dev.langchain4j.data.embedding.Embedding;
8+
import org.w3c.dom.Document;
9+
import org.w3c.dom.Element;
10+
import org.w3c.dom.NodeList;
11+
12+
import javax.xml.xpath.XPathConstants;
13+
import javax.xml.xpath.XPathExpressionException;
14+
import javax.xml.xpath.XPathFactory;
15+
16+
public class DOMChunk implements Chunk {
17+
18+
private final String documentUri;
19+
private final Document document;
20+
private final Element chunkElement;
21+
private final String textExpression;
22+
private final XPathFactory xpathFactory;
23+
24+
public DOMChunk(String documentUri, Document document, Element chunkElement, String textExpression, XPathFactory xpathFactory) {
25+
this.documentUri = documentUri;
26+
this.document = document;
27+
this.chunkElement = chunkElement;
28+
this.textExpression = textExpression;
29+
this.xpathFactory = xpathFactory;
30+
}
31+
32+
@Override
33+
public String getDocumentUri() {
34+
return documentUri;
35+
}
36+
37+
@Override
38+
public String getEmbeddingText() {
39+
NodeList embeddingTextNodes;
40+
try {
41+
embeddingTextNodes = (NodeList) xpathFactory.newXPath().evaluate(textExpression, chunkElement, XPathConstants.NODESET);
42+
} catch (XPathExpressionException e) {
43+
throw new ConnectorException(String.format("Unable to evaluate XPath expression: %s; cause: %s",
44+
textExpression, e.getMessage()), e);
45+
}
46+
47+
return concatenateNodesIntoString(embeddingTextNodes);
48+
}
49+
50+
@Override
51+
public void addEmbedding(Embedding embedding) {
52+
this.document.createElement("embedding").setTextContent(embedding.vectorAsList().toString());
53+
}
54+
55+
private String concatenateNodesIntoString(NodeList embeddingTextNodes) {
56+
StringBuilder builder = new StringBuilder();
57+
for (int i = 0; i < embeddingTextNodes.getLength(); i++) {
58+
if (i > 0) {
59+
builder.append(" ");
60+
}
61+
builder.append(embeddingTextNodes.item(i).getTextContent());
62+
}
63+
return builder.toString().trim();
64+
}
65+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.writer.embedding;
5+
6+
import com.marklogic.client.document.DocumentWriteOperation;
7+
import com.marklogic.client.impl.DocumentWriteOperationImpl;
8+
import com.marklogic.client.impl.HandleAccessor;
9+
import com.marklogic.client.io.DOMHandle;
10+
import com.marklogic.client.io.marker.AbstractWriteHandle;
11+
import com.marklogic.spark.ConnectorException;
12+
import org.w3c.dom.Document;
13+
import org.w3c.dom.Element;
14+
import org.w3c.dom.Node;
15+
import org.w3c.dom.NodeList;
16+
import org.xml.sax.InputSource;
17+
18+
import javax.xml.parsers.DocumentBuilderFactory;
19+
import javax.xml.xpath.XPathConstants;
20+
import javax.xml.xpath.XPathExpression;
21+
import javax.xml.xpath.XPathExpressionException;
22+
import javax.xml.xpath.XPathFactory;
23+
import java.io.StringReader;
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
27+
public class DOMChunkSelector implements ChunkSelector {
28+
29+
private final XPathFactory xpathFactory;
30+
private final XPathExpression chunksExpression;
31+
private final String chunkTextExpression;
32+
private final DocumentBuilderFactory documentBuilderFactory;
33+
34+
public DOMChunkSelector(String chunksExpression, String chunkTextExpression) {
35+
this.xpathFactory = XPathFactory.newInstance();
36+
try {
37+
this.chunksExpression = this.xpathFactory.newXPath().compile(chunksExpression);
38+
} catch (XPathExpressionException e) {
39+
throw new ConnectorException(String.format(
40+
"Unable to compile XPath expression for selecting chunks: %s; cause: %s", chunksExpression, e.getMessage()), e);
41+
}
42+
this.chunkTextExpression = chunkTextExpression;
43+
this.documentBuilderFactory = DocumentBuilderFactory.newInstance();
44+
}
45+
46+
@Override
47+
public DocumentAndChunks selectChunks(DocumentWriteOperation sourceDocument) {
48+
Document doc = extractDocument(sourceDocument);
49+
50+
NodeList chunkNodes = selectChunkNodes(doc);
51+
if (chunkNodes.getLength() == 0) {
52+
return new DocumentAndChunks(sourceDocument, null);
53+
}
54+
55+
List<Chunk> chunks = makeChunks(sourceDocument, doc, chunkNodes);
56+
DocumentWriteOperation docToWrite = new DocumentWriteOperationImpl(sourceDocument.getUri(),
57+
sourceDocument.getMetadata(), new DOMHandle(doc));
58+
return new DocumentAndChunks(docToWrite, chunks);
59+
}
60+
61+
private Document extractDocument(DocumentWriteOperation sourceDocument) {
62+
AbstractWriteHandle handle = sourceDocument.getContent();
63+
if (handle instanceof DOMHandle) {
64+
return ((DOMHandle) handle).get();
65+
}
66+
String xml = HandleAccessor.contentAsString(handle);
67+
try {
68+
return documentBuilderFactory.newDocumentBuilder().parse(new InputSource(new StringReader(xml)));
69+
} catch (Exception e) {
70+
throw new ConnectorException(String.format("Unable to parse XML for document with URI: %s; cause: %s",
71+
sourceDocument.getUri(), e.getMessage()), e);
72+
}
73+
}
74+
75+
private NodeList selectChunkNodes(Document doc) {
76+
try {
77+
return (NodeList) chunksExpression.evaluate(doc, XPathConstants.NODESET);
78+
} catch (XPathExpressionException e) {
79+
throw new ConnectorException(String.format(
80+
"Unable to evaluate XPath expression for selecting chunks: %s; cause: %s", chunksExpression, e.getMessage()), e);
81+
}
82+
}
83+
84+
private List<Chunk> makeChunks(DocumentWriteOperation sourceDocument, Document document, NodeList chunkNodes) {
85+
List<Chunk> chunks = new ArrayList<>();
86+
for (int i = 0; i < chunkNodes.getLength(); i++) {
87+
Node node = chunkNodes.item(i);
88+
if (node.getNodeType() != Node.ELEMENT_NODE) {
89+
throw new ConnectorException(String.format("XPath expression for selecting chunks must only " +
90+
"select elements; XPath: %s; document URI: %s", chunksExpression, sourceDocument.getUri()));
91+
}
92+
chunks.add(new DOMChunk(sourceDocument.getUri(), document, (Element) node, chunkTextExpression, xpathFactory));
93+
}
94+
return chunks;
95+
}
96+
}

src/main/java/com/marklogic/spark/writer/embedding/XmlChunkSelector.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ static class Builder {
3535

3636
/**
 * Builds an XmlChunkSelector from the values held by this builder.
 * In this hunk, the single compile call (marked 38-) is replaced by a null check on xpathNamespaces (marked
 * 38+ to 40+): the four-argument compile overload is used only when namespaces are present, otherwise the
 * two-argument overload is used.
 */
public XmlChunkSelector build() {
3737
// Falls back to "/node()/chunks" when no chunks XPath expression was set.
String tmp = chunksXPathExpression != null ? chunksXPathExpression : "/node()/chunks";
38-
XPathExpression<Element> chunksExpression = XPathFactory.instance().compile(tmp, Filters.element(), null, xpathNamespaces);
38+
XPathExpression<Element> chunksExpression = xpathNamespaces != null ?
39+
XPathFactory.instance().compile(tmp, Filters.element(), null, xpathNamespaces) :
40+
XPathFactory.instance().compile(tmp, Filters.element());
3941
return new XmlChunkSelector(chunksExpression, textXPathExpression, embeddingName, embeddingNamespace, xpathNamespaces);
4042
}
4143

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
3+
*/
4+
package com.marklogic.spark.writer.embedding;
5+
6+
import com.marklogic.client.document.DocumentWriteOperation;
7+
import com.marklogic.client.impl.DocumentWriteOperationImpl;
8+
import com.marklogic.client.io.DocumentMetadataHandle;
9+
import com.marklogic.client.io.Format;
10+
import com.marklogic.client.io.StringHandle;
11+
import org.junit.jupiter.params.ParameterizedTest;
12+
import org.junit.jupiter.params.provider.CsvSource;
13+
14+
import static org.junit.jupiter.api.Assertions.assertEquals;
15+
16+
class DOMChunkSelectorTest {
17+
18+
private static final String CHUNKS_EXPRESSION = "/root/chunk";
19+
20+
private static final String XML = "<root>" +
21+
"<chunk>" +
22+
"<text hidden='false'>Hello <b>bold text</b></text>" +
23+
"<other>Other text</other>" +
24+
"<status enabled='true'/>" +
25+
"</chunk>" +
26+
"</root>";
27+
28+
@ParameterizedTest
29+
@CsvSource({
30+
"text,Hello bold text",
31+
"node()/@*,false true",
32+
"node(),Hello bold text Other text",
33+
"node()/text(),Hello Other text",
34+
"node()[self::text or self::other]//text(),Hello bold text Other text"
35+
})
36+
void test(String textExpression, String expectedChunkText) {
37+
String actualChunkText = new DOMChunkSelector(CHUNKS_EXPRESSION, textExpression)
38+
.selectChunks(makeDocument(XML))
39+
.getChunks().get(0).getEmbeddingText();
40+
41+
assertEquals(expectedChunkText, actualChunkText);
42+
}
43+
44+
private DocumentWriteOperation makeDocument(String xml) {
45+
return new DocumentWriteOperationImpl(
46+
"/test.xml", new DocumentMetadataHandle(),
47+
new StringHandle(xml).withFormat(Format.XML)
48+
);
49+
}
50+
}

0 commit comments

Comments
 (0)