Skip to content

Commit 15fe2e5

Browse files
committed
Make chat history available to RetrievalAugmentationAdvisor
* Extend Query with conversation history in RetrievalAugmentationAdvisor * Add integration tests for query compression and rewrite Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent d7fe07b commit 15fe2e5

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,11 @@ public static Builder builder() {
102102
public AdvisedRequest before(AdvisedRequest request) {
103103
Map<String, Object> context = new HashMap<>(request.adviseContext());
104104

105-
// 0. Create a query from the user text and parameters.
106-
Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
105+
// 0. Create a query from the user text, parameters, and conversation history.
106+
Query originalQuery = Query.builder()
107+
.text(new PromptTemplate(request.userText(), request.userParams()).render())
108+
.history(request.messages())
109+
.build();
107110

108111
// 1. Transform original user query based on a chain of query transformers.
109112
Query transformedQuery = originalQuery;

spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2525

2626
import org.springframework.ai.chat.client.ChatClient;
27+
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
28+
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
2729
import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor;
30+
import org.springframework.ai.chat.memory.InMemoryChatMemory;
2831
import org.springframework.ai.chat.model.ChatResponse;
2932
import org.springframework.ai.document.Document;
3033
import org.springframework.ai.document.DocumentReader;
@@ -34,6 +37,8 @@
3437
import org.springframework.ai.integration.tests.TestApplication;
3538
import org.springframework.ai.openai.OpenAiChatModel;
3639
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
40+
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
41+
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
3742
import org.springframework.ai.rag.preretrieval.query.transformation.TranslationQueryTransformer;
3843
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
3944
import org.springframework.ai.reader.markdown.MarkdownDocumentReader;
@@ -103,6 +108,74 @@ void ragBasic() {
103108
evaluateRelevancy(question, chatResponse);
104109
}
105110

111+
@Test
112+
void ragWithCompression() {
113+
MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(new InMemoryChatMemory()).build();
114+
115+
RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
116+
.queryTransformers(CompressionQueryTransformer.builder()
117+
.chatClientBuilder(ChatClient.builder(this.openAiChatModel))
118+
.build())
119+
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
120+
.build();
121+
122+
ChatClient chatClient = ChatClient.builder(this.openAiChatModel)
123+
.defaultAdvisors(memoryAdvisor, ragAdvisor)
124+
.build();
125+
126+
String conversationId = "007";
127+
128+
ChatResponse chatResponse1 = chatClient.prompt()
129+
.user("Where does the adventure of Anacletus and Birba take place?")
130+
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY,
131+
conversationId))
132+
.call()
133+
.chatResponse();
134+
135+
assertThat(chatResponse1).isNotNull();
136+
String response1 = chatResponse1.getResult().getOutput().getText();
137+
System.out.println(response1);
138+
139+
ChatResponse chatResponse2 = chatClient.prompt()
140+
.user("Did they meet any cow?")
141+
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY,
142+
conversationId))
143+
.call()
144+
.chatResponse();
145+
146+
assertThat(chatResponse2).isNotNull();
147+
String response2 = chatResponse2.getResult().getOutput().getText();
148+
System.out.println(response2);
149+
assertThat(response2.toLowerCase()).containsIgnoringCase("Fergus");
150+
}
151+
152+
@Test
153+
void ragWithRewrite() {
154+
String question = "Where are the main characters going?";
155+
156+
RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder()
157+
.queryTransformers(RewriteQueryTransformer.builder()
158+
.chatClientBuilder(ChatClient.builder(this.openAiChatModel))
159+
.targetSearchSystem("vector store")
160+
.build())
161+
.documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(this.pgVectorStore).build())
162+
.build();
163+
164+
ChatResponse chatResponse = ChatClient.builder(this.openAiChatModel).build().prompt()
165+
.user(question)
166+
.advisors(ragAdvisor)
167+
.call()
168+
.chatResponse();
169+
170+
assertThat(chatResponse).isNotNull();
171+
172+
String response = chatResponse.getResult().getOutput().getText();
173+
System.out.println(response);
174+
assertThat(response).containsIgnoringCase("Loch of the Stars");
175+
176+
evaluateRelevancy(question, chatResponse);
177+
}
178+
106179
@Test
107180
void ragWithTranslation() {
108181
String question = "Hvor finder Anacletus og Birbas eventyr sted?";

0 commit comments

Comments
 (0)