Skip to content

Commit 2082a59

Browse files
YunKuiLumarkpollack
authored andcommitted
feat: Add custom template support to KeywordMetadataEnricher
- Support the use of custom templates - Adding builder pattern - Adding unit tests - Adjusted the relevant documents Signed-off-by: YunKui Lu <luyunkui95@gmail.com>
1 parent 2e579b1 commit 2082a59

File tree

3 files changed

+273
-17
lines changed

3 files changed

+273
-17
lines changed

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/etl-pipeline.adoc

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -709,18 +709,26 @@ class MyKeywordEnricher {
709709
}
710710
711711
List<Document> enrichDocuments(List<Document> documents) {
712-
KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(this.chatModel, 5);
712+
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
713+
.keywordCount(5)
714+
.build();
715+
716+
// Or use custom templates
717+
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
718+
.keywordsTemplate(YOUR_CUSTOM_TEMPLATE)
719+
.build();
720+
713721
return enricher.apply(documents);
714722
}
715723
}
716724
----
717725

718-
==== Constructor
726+
==== Constructor Options
719727

720-
The `KeywordMetadataEnricher` constructor takes two parameters:
728+
The `KeywordMetadataEnricher` provides two constructor options:
721729

722-
1. `ChatModel chatModel`: The AI model used for generating keywords.
723-
2. `int keywordCount`: The number of keywords to extract for each document.
730+
1. `KeywordMetadataEnricher(ChatModel chatModel, int keywordCount)`: To use the default template and extract a specified number of keywords.
731+
2. `KeywordMetadataEnricher(ChatModel chatModel, PromptTemplate keywordsTemplate)`: To use a custom template for keyword extraction.
724732

725733
==== Behavior
726734

@@ -734,7 +742,8 @@ The `KeywordMetadataEnricher` processes documents as follows:
734742

735743
==== Customization
736744

737-
The keyword extraction prompt can be customized by modifying the `KEYWORDS_TEMPLATE` constant in the class. The default template is:
745+
You can use the default template or customize the template through the keywordsTemplate parameter.
746+
The default template is:
738747

739748
[source,java]
740749
----
@@ -748,7 +757,14 @@ Where `+{context_str}+` is replaced with the document content, and `%s` is repla
748757
[source,java]
749758
----
750759
ChatModel chatModel = // initialize your chat model
751-
KeywordMetadataEnricher enricher = new KeywordMetadataEnricher(chatModel, 5);
760+
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
761+
.keywordCount(5)
762+
.build();
763+
764+
// Or use custom templates
765+
KeywordMetadataEnricher enricher = KeywordMetadataEnricher.builder(chatModel)
766+
.keywordsTemplate(new PromptTemplate("Extract 5 important keywords from the following text and separate them with commas:\n{context_str}"))
767+
.build();
752768
753769
Document doc = new Document("This is a document about artificial intelligence and its applications in modern technology.");
754770
@@ -766,6 +782,7 @@ System.out.println("Extracted keywords: " + keywords);
766782
* The enricher adds the "excerpt_keywords" metadata field to each processed document.
767783
* The generated keywords are returned as a comma-separated string.
768784
* This enricher is particularly useful for improving document searchability and for generating tags or categories for documents.
785+
* In the Builder pattern, if the `keywordsTemplate` parameter is set, the `keywordCount` parameter will be ignored.
769786

770787
=== SummaryMetadataEnricher
771788
The `SummaryMetadataEnricher` is a `DocumentTransformer` that uses a generative AI model to create summaries for documents and add them as metadata. It can generate summaries for the current document, as well as adjacent documents (previous and next).

spring-ai-model/src/main/java/org/springframework/ai/model/transformer/KeywordMetadataEnricher.java

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import java.util.List;
2020
import java.util.Map;
2121

22+
import org.slf4j.Logger;
23+
import org.slf4j.LoggerFactory;
24+
2225
import org.springframework.ai.chat.model.ChatModel;
2326
import org.springframework.ai.chat.prompt.Prompt;
2427
import org.springframework.ai.chat.prompt.PromptTemplate;
@@ -30,45 +33,113 @@
3033
* Keyword extractor that uses generative to extract 'excerpt_keywords' metadata field.
3134
*
3235
* @author Christian Tzolov
36+
* @author YunKui Lu
3337
*/
3438
public class KeywordMetadataEnricher implements DocumentTransformer {
3539

40+
private static final Logger logger = LoggerFactory.getLogger(KeywordMetadataEnricher.class);
41+
3642
public static final String CONTEXT_STR_PLACEHOLDER = "context_str";
3743

3844
public static final String KEYWORDS_TEMPLATE = """
3945
{context_str}. Give %s unique keywords for this
4046
document. Format as comma separated. Keywords: """;
4147

42-
private static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords";
48+
public static final String EXCERPT_KEYWORDS_METADATA_KEY = "excerpt_keywords";
4349

4450
/**
4551
* Model predictor
4652
*/
4753
private final ChatModel chatModel;
4854

4955
/**
50-
* The number of keywords to extract.
56+
* The prompt template to use for keyword extraction.
5157
*/
52-
private final int keywordCount;
58+
private final PromptTemplate keywordsTemplate;
5359

60+
/**
61+
* Create a new {@link KeywordMetadataEnricher} instance.
62+
* @param chatModel the model predictor to use for keyword extraction.
63+
* @param keywordCount the number of keywords to extract.
64+
*/
5465
public KeywordMetadataEnricher(ChatModel chatModel, int keywordCount) {
55-
Assert.notNull(chatModel, "ChatModel must not be null");
56-
Assert.isTrue(keywordCount >= 1, "Document count must be >= 1");
66+
Assert.notNull(chatModel, "chatModel must not be null");
67+
Assert.isTrue(keywordCount >= 1, "keywordCount must be >= 1");
68+
69+
this.chatModel = chatModel;
70+
this.keywordsTemplate = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount));
71+
}
72+
73+
/**
74+
* Create a new {@link KeywordMetadataEnricher} instance.
75+
* @param chatModel the model predictor to use for keyword extraction.
76+
* @param keywordsTemplate the prompt template to use for keyword extraction.
77+
*/
78+
public KeywordMetadataEnricher(ChatModel chatModel, PromptTemplate keywordsTemplate) {
79+
Assert.notNull(chatModel, "chatModel must not be null");
80+
Assert.notNull(keywordsTemplate, "keywordsTemplate must not be null");
5781

5882
this.chatModel = chatModel;
59-
this.keywordCount = keywordCount;
83+
this.keywordsTemplate = keywordsTemplate;
6084
}
6185

6286
@Override
6387
public List<Document> apply(List<Document> documents) {
6488
for (Document document : documents) {
65-
66-
var template = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, this.keywordCount));
67-
Prompt prompt = template.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getText()));
89+
Prompt prompt = this.keywordsTemplate.create(Map.of(CONTEXT_STR_PLACEHOLDER, document.getText()));
6890
String keywords = this.chatModel.call(prompt).getResult().getOutput().getText();
69-
document.getMetadata().putAll(Map.of(EXCERPT_KEYWORDS_METADATA_KEY, keywords));
91+
document.getMetadata().put(EXCERPT_KEYWORDS_METADATA_KEY, keywords);
7092
}
7193
return documents;
7294
}
7395

96+
// Exposed for testing purposes
97+
PromptTemplate getKeywordsTemplate() {
98+
return this.keywordsTemplate;
99+
}
100+
101+
public static Builder builder(ChatModel chatModel) {
102+
return new Builder(chatModel);
103+
}
104+
105+
public static class Builder {
106+
107+
private final ChatModel chatModel;
108+
109+
private int keywordCount;
110+
111+
private PromptTemplate keywordsTemplate;
112+
113+
public Builder(ChatModel chatModel) {
114+
Assert.notNull(chatModel, "The chatModel must not be null");
115+
this.chatModel = chatModel;
116+
}
117+
118+
public Builder keywordCount(int keywordCount) {
119+
Assert.isTrue(keywordCount >= 1, "The keywordCount must be >= 1");
120+
this.keywordCount = keywordCount;
121+
return this;
122+
}
123+
124+
public Builder keywordsTemplate(PromptTemplate keywordsTemplate) {
125+
Assert.notNull(keywordsTemplate, "The keywordsTemplate must not be null");
126+
this.keywordsTemplate = keywordsTemplate;
127+
return this;
128+
}
129+
130+
public KeywordMetadataEnricher build() {
131+
if (this.keywordsTemplate != null) {
132+
133+
if (this.keywordCount != 0) {
134+
logger.warn("keywordCount will be ignored as keywordsTemplate is set.");
135+
}
136+
137+
return new KeywordMetadataEnricher(this.chatModel, this.keywordsTemplate);
138+
}
139+
140+
return new KeywordMetadataEnricher(this.chatModel, this.keywordCount);
141+
}
142+
143+
}
144+
74145
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package org.springframework.ai.model.transformer;
2+
3+
import java.util.List;
4+
import java.util.Map;
5+
6+
import org.junit.jupiter.api.Test;
7+
import org.junit.jupiter.api.extension.ExtendWith;
8+
import org.mockito.ArgumentCaptor;
9+
import org.mockito.Captor;
10+
import org.mockito.Mock;
11+
import org.mockito.junit.jupiter.MockitoExtension;
12+
13+
import org.springframework.ai.chat.messages.AssistantMessage;
14+
import org.springframework.ai.chat.model.ChatModel;
15+
import org.springframework.ai.chat.model.ChatResponse;
16+
import org.springframework.ai.chat.model.Generation;
17+
import org.springframework.ai.chat.prompt.Prompt;
18+
import org.springframework.ai.chat.prompt.PromptTemplate;
19+
import org.springframework.ai.document.Document;
20+
21+
import static org.assertj.core.api.Assertions.assertThat;
22+
import static org.junit.jupiter.api.Assertions.assertThrows;
23+
import static org.mockito.BDDMockito.given;
24+
import static org.mockito.Mockito.*;
25+
import static org.springframework.ai.model.transformer.KeywordMetadataEnricher.*;
26+
27+
/**
28+
* @author YunKui Lu
29+
*/
30+
@ExtendWith(MockitoExtension.class)
31+
class KeywordMetadataEnricherTest {
32+
33+
@Mock
34+
private ChatModel chatModel;
35+
36+
@Captor
37+
private ArgumentCaptor<Prompt> promptCaptor;
38+
39+
private final String CUSTOM_TEMPLATE = "Custom template: {context_str}";
40+
41+
@Test
42+
void testUseWithDefaultTemplate() {
43+
// 1. Prepare test data
44+
// @formatter:off
45+
List<Document> documents = List.of(
46+
new Document("content1"),
47+
new Document("content2"),
48+
new Document("content3"));// @formatter:on
49+
int keywordCount = 3;
50+
51+
// 2. Mock
52+
given(chatModel.call(any(Prompt.class))).willReturn(
53+
new ChatResponse(List.of(new Generation(new AssistantMessage("keyword1-1, keyword1-2, keyword1-3")))),
54+
new ChatResponse(List.of(new Generation(new AssistantMessage("keyword2-1, keyword2-2, keyword2-3")))),
55+
new ChatResponse(List.of(new Generation(new AssistantMessage("keyword3-1, keyword3-2, keyword3-3")))));
56+
57+
// 3. Create instance
58+
KeywordMetadataEnricher keywordMetadataEnricher = new KeywordMetadataEnricher(chatModel, keywordCount);
59+
60+
// 4. Apply
61+
keywordMetadataEnricher.apply(documents);
62+
63+
// 5. Assert
64+
verify(chatModel, times(3)).call(promptCaptor.capture());
65+
66+
assertThat(promptCaptor.getAllValues().get(0).getUserMessage().getText())
67+
.isEqualTo(getDefaultTemplatePromptText(keywordCount, "content1"));
68+
assertThat(promptCaptor.getAllValues().get(1).getUserMessage().getText())
69+
.isEqualTo(getDefaultTemplatePromptText(keywordCount, "content2"));
70+
assertThat(promptCaptor.getAllValues().get(2).getUserMessage().getText())
71+
.isEqualTo(getDefaultTemplatePromptText(keywordCount, "content3"));
72+
73+
assertThat(documents.get(0).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
74+
"keyword1-1, keyword1-2, keyword1-3");
75+
assertThat(documents.get(1).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
76+
"keyword2-1, keyword2-2, keyword2-3");
77+
assertThat(documents.get(2).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
78+
"keyword3-1, keyword3-2, keyword3-3");
79+
}
80+
81+
@Test
82+
void testUseCustomTemplate() {
83+
// 1. Prepare test data
84+
// @formatter:off
85+
List<Document> documents = List.of(
86+
new Document("content1"),
87+
new Document("content2"),
88+
new Document("content3"));// @formatter:on
89+
PromptTemplate promptTemplate = new PromptTemplate(CUSTOM_TEMPLATE);
90+
91+
// 2. Mock
92+
given(chatModel.call(any(Prompt.class))).willReturn(
93+
new ChatResponse(List.of(new Generation(new AssistantMessage("keyword1-1, keyword1-2, keyword1-3")))),
94+
new ChatResponse(List.of(new Generation(new AssistantMessage("keyword2-1, keyword2-2, keyword2-3")))),
95+
new ChatResponse(List.of(new Generation(new AssistantMessage("keyword3-1, keyword3-2, keyword3-3")))));
96+
97+
// 3. Create instance
98+
KeywordMetadataEnricher keywordMetadataEnricher = new KeywordMetadataEnricher(this.chatModel, promptTemplate);
99+
100+
// 4. Apply
101+
keywordMetadataEnricher.apply(documents);
102+
103+
// 5. Assert
104+
verify(chatModel, times(documents.size())).call(promptCaptor.capture());
105+
106+
assertThat(promptCaptor.getAllValues().get(0).getUserMessage().getText())
107+
.isEqualTo("Custom template: content1");
108+
assertThat(promptCaptor.getAllValues().get(1).getUserMessage().getText())
109+
.isEqualTo("Custom template: content2");
110+
assertThat(promptCaptor.getAllValues().get(2).getUserMessage().getText())
111+
.isEqualTo("Custom template: content3");
112+
113+
assertThat(documents.get(0).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
114+
"keyword1-1, keyword1-2, keyword1-3");
115+
assertThat(documents.get(1).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
116+
"keyword2-1, keyword2-2, keyword2-3");
117+
assertThat(documents.get(2).getMetadata()).containsEntry(EXCERPT_KEYWORDS_METADATA_KEY,
118+
"keyword3-1, keyword3-2, keyword3-3");
119+
}
120+
121+
@Test
122+
void testConstructorThrowsException() {
123+
assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(null, 3),
124+
"chatModel must not be null");
125+
126+
assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(chatModel, 0),
127+
"keywordCount must be >= 1");
128+
129+
assertThrows(IllegalArgumentException.class, () -> new KeywordMetadataEnricher(chatModel, null),
130+
"keywordsTemplate must not be null");
131+
}
132+
133+
@Test
134+
void testBuilderThrowsException() {
135+
assertThrows(IllegalArgumentException.class, () -> KeywordMetadataEnricher.builder(null),
136+
"The chatModel must not be null");
137+
138+
Builder builder = builder(chatModel);
139+
assertThrows(IllegalArgumentException.class, () -> builder.keywordCount(0), "The keywordCount must be >= 1");
140+
141+
assertThrows(IllegalArgumentException.class, () -> builder.keywordsTemplate(null),
142+
"The keywordsTemplate must not be null");
143+
}
144+
145+
@Test
146+
void testBuilderWithKeywordCount() {
147+
int keywordCount = 3;
148+
KeywordMetadataEnricher enricher = builder(chatModel).keywordCount(keywordCount).build();
149+
150+
assertThat(enricher.getKeywordsTemplate().getTemplate())
151+
.isEqualTo(String.format(KEYWORDS_TEMPLATE, keywordCount));
152+
}
153+
154+
@Test
155+
void testBuilderWithKeywordsTemplate() {
156+
PromptTemplate template = new PromptTemplate(CUSTOM_TEMPLATE);
157+
KeywordMetadataEnricher enricher = builder(chatModel).keywordsTemplate(template).build();
158+
159+
assertThat(enricher).extracting("chatModel", "keywordsTemplate").containsExactly(chatModel, template);
160+
}
161+
162+
private String getDefaultTemplatePromptText(int keywordCount, String documentContent) {
163+
PromptTemplate promptTemplate = new PromptTemplate(String.format(KEYWORDS_TEMPLATE, keywordCount));
164+
Prompt prompt = promptTemplate.create(Map.of(CONTEXT_STR_PLACEHOLDER, documentContent));
165+
return prompt.getContents();
166+
}
167+
168+
}

0 commit comments

Comments
 (0)