Skip to content

Commit 848a3fd

Browse files
committed
refactor: Simplify chat memory advisor hierarchy and remove deprecated API
- Remove deprecated ChatMemory.get(String conversationId, int lastN) method
- Replace AbstractChatMemoryAdvisor with BaseChatMemoryAdvisor interface in api package
- Make constructors private in all memory advisor implementations to enforce builder usage
- Rename CHAT_MEMORY_CONVERSATION_ID_KEY to CONVERSATION_ID and move to ChatMemory interface
- In VectorStoreChatMemoryAdvisor:
- Rename DEFAULT_CHAT_MEMORY_RESPONSE_SIZE (100) to DEFAULT_TOP_K (20)
- Rename builder method chatMemoryRetrieveSize() to topK()
- Remove systemTextAdvise() builder method
- In PromptChatMemoryAdvisor:
- Remove systemTextAdvise() builder method
- Fix bug where only the last user message was stored from prompts with multiple messages
- Enhance logging in memory advisors to aid in debugging
- Add comprehensive tests for all advisor implementations:
- Unit tests for builder behavior
- Integration tests for the various chat memory advisors

Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
1 parent 7b15e18 commit 848a3fd

File tree

22 files changed

+1685
-473
lines changed

22 files changed

+1685
-473
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 148 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,25 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.stream.Collectors;
2423

25-
import reactor.core.publisher.Flux;
24+
import org.slf4j.Logger;
25+
import org.slf4j.LoggerFactory;
26+
import reactor.core.scheduler.Scheduler;
27+
import reactor.core.scheduler.Schedulers;
2628

27-
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
2829
import org.springframework.ai.chat.client.ChatClientRequest;
2930
import org.springframework.ai.chat.client.ChatClientResponse;
30-
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
31-
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
31+
import org.springframework.ai.chat.client.advisor.api.Advisor;
32+
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
33+
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
34+
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
35+
import org.springframework.ai.chat.memory.ChatMemory;
3236
import org.springframework.ai.chat.messages.AssistantMessage;
3337
import org.springframework.ai.chat.messages.Message;
3438
import org.springframework.ai.chat.messages.MessageType;
35-
import org.springframework.ai.chat.messages.SystemMessage;
3639
import org.springframework.ai.chat.messages.UserMessage;
37-
import org.springframework.ai.chat.model.MessageAggregator;
3840
import org.springframework.ai.chat.prompt.PromptTemplate;
3941
import org.springframework.ai.document.Document;
40-
import org.springframework.ai.vectorstore.SearchRequest;
4142
import org.springframework.ai.vectorstore.VectorStore;
4243

4344
/**
@@ -48,14 +49,22 @@
4849
* @author Christian Tzolov
4950
* @author Thomas Vitale
5051
* @author Oganes Bozoyan
52+
* @author Mark Pollack
5153
* @since 1.0.0
5254
*/
53-
public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<VectorStore> {
55+
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
56+
57+
public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";
5458

5559
private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";
5660

5761
private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";
5862

63+
/**
64+
* The default chat memory retrieve size to use when no retrieve size is provided.
65+
*/
66+
public static final int DEFAULT_TOP_K = 20;
67+
5968
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
6069
{instructions}
6170
@@ -69,71 +78,84 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
6978

7079
private final PromptTemplate systemPromptTemplate;
7180

72-
private VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
73-
int chatHistoryWindowSize, boolean protectFromBlocking, PromptTemplate systemPromptTemplate, int order) {
74-
super(vectorStore, defaultConversationId, chatHistoryWindowSize, protectFromBlocking, order);
81+
protected final int defaultChatMemoryRetrieveSize;
82+
83+
private final String defaultConversationId;
84+
85+
private final int order;
86+
87+
private final Scheduler scheduler;
88+
89+
private VectorStore vectorStore;
90+
91+
public VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultChatMemoryRetrieveSize,
92+
String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
7593
this.systemPromptTemplate = systemPromptTemplate;
94+
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
95+
this.defaultConversationId = defaultConversationId;
96+
this.order = order;
97+
this.scheduler = scheduler;
98+
this.vectorStore = vectorStore;
7699
}
77100

78101
public static Builder builder(VectorStore chatMemory) {
79102
return new Builder(chatMemory);
80103
}
81104

82105
@Override
83-
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
84-
chatClientRequest = this.before(chatClientRequest);
85-
86-
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
87-
88-
this.after(chatClientResponse);
89-
90-
return chatClientResponse;
106+
public int getOrder() {
107+
return order;
91108
}
92109

93110
@Override
94-
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
95-
StreamAdvisorChain streamAdvisorChain) {
96-
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
97-
streamAdvisorChain, this::before);
98-
99-
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
111+
public Scheduler getScheduler() {
112+
return this.scheduler;
100113
}
101114

102-
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
103-
String conversationId = this.doGetConversationId(chatClientRequest.context());
104-
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
105-
106-
// 1. Retrieve the chat memory for the current conversation.
107-
var searchRequest = SearchRequest.builder()
108-
.query(chatClientRequest.prompt().getUserMessage().getText())
109-
.topK(chatMemoryRetrieveSize)
110-
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
115+
@Override
116+
public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
117+
String conversationId = getConversationId(request.context());
118+
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
119+
int topK = getChatMemoryTopK(request.context());
120+
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
121+
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
122+
.query(query)
123+
.topK(topK)
124+
.filterExpression(filter)
111125
.build();
126+
java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore
127+
.similaritySearch(searchRequest);
112128

113-
List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);
114-
115-
// 2. Processed memory messages as a string.
116129
String longTermMemory = documents == null ? ""
117-
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
130+
: documents.stream()
131+
.map(org.springframework.ai.document.Document::getText)
132+
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));
118133

119-
// 2. Augment the system message.
120-
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
134+
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
121135
String augmentedSystemText = this.systemPromptTemplate
122-
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
136+
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
123137

124-
// 3. Create a new request with the augmented system message.
125-
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
126-
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
138+
ChatClientRequest processedChatClientRequest = request.mutate()
139+
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
127140
.build();
128141

129-
// 4. Add the new user message to the conversation memory.
130-
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
131-
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
142+
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
143+
.getUserMessage();
144+
if (userMessage != null) {
145+
this.vectorStore.write(toDocuments(java.util.List.of(userMessage), conversationId));
146+
}
132147

133148
return processedChatClientRequest;
134149
}
135150

136-
private void after(ChatClientResponse chatClientResponse) {
151+
private int getChatMemoryTopK(Map<String, Object> context) {
152+
return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)
153+
? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())
154+
: this.defaultChatMemoryRetrieveSize;
155+
}
156+
157+
@Override
158+
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
137159
List<Message> assistantMessages = new ArrayList<>();
138160
if (chatClientResponse.chatResponse() != null) {
139161
assistantMessages = chatClientResponse.chatResponse()
@@ -142,8 +164,8 @@ private void after(ChatClientResponse chatClientResponse) {
142164
.map(g -> (Message) g.getOutput())
143165
.toList();
144166
}
145-
this.getChatMemoryStore()
146-
.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
167+
this.vectorStore.write(toDocuments(assistantMessages, this.getConversationId(chatClientResponse.context())));
168+
return chatClientResponse;
147169
}
148170

149171
private List<Document> toDocuments(List<Message> messages, String conversationId) {
@@ -173,28 +195,93 @@ else if (message instanceof AssistantMessage assistantMessage) {
173195
return docs;
174196
}
175197

176-
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> {
198+
/**
199+
* Builder for VectorStoreChatMemoryAdvisor.
200+
*/
201+
public static class Builder {
177202

178203
private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;
179204

180-
protected Builder(VectorStore chatMemory) {
181-
super(chatMemory);
182-
}
205+
private Integer topK = DEFAULT_TOP_K;
183206

184-
public Builder systemTextAdvise(String systemTextAdvise) {
185-
this.systemPromptTemplate = new PromptTemplate(systemTextAdvise);
186-
return this;
207+
private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
208+
209+
private Scheduler scheduler;
210+
211+
private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;
212+
213+
private VectorStore vectorStore;
214+
215+
/**
216+
* Creates a new builder instance.
217+
* @param vectorStore the vector store to use
218+
*/
219+
protected Builder(VectorStore vectorStore) {
220+
this.vectorStore = vectorStore;
187221
}
188222

223+
/**
224+
* Set the system prompt template.
225+
* @param systemPromptTemplate the system prompt template
226+
* @return this builder
227+
*/
189228
public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
190229
this.systemPromptTemplate = systemPromptTemplate;
191230
return this;
192231
}
193232

194-
@Override
233+
/**
234+
* Set the chat memory retrieve size.
235+
* @param topK the chat memory retrieve size
236+
* @return this builder
237+
*/
238+
public Builder topK(int topK) {
239+
this.topK = topK;
240+
return this;
241+
}
242+
243+
/**
244+
* Set the conversation id.
245+
* @param conversationId the conversation id
246+
* @return the builder
247+
*/
248+
public Builder conversationId(String conversationId) {
249+
this.conversationId = conversationId;
250+
return this;
251+
}
252+
253+
/**
254+
* Set whether to protect from blocking.
255+
* @param protectFromBlocking whether to protect from blocking
256+
* @return the builder
257+
*/
258+
public Builder protectFromBlocking(boolean protectFromBlocking) {
259+
this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
260+
return this;
261+
}
262+
263+
public Builder scheduler(Scheduler scheduler) {
264+
this.scheduler = scheduler;
265+
return this;
266+
}
267+
268+
/**
269+
* Set the order.
270+
* @param order the order
271+
* @return the builder
272+
*/
273+
public Builder order(int order) {
274+
this.order = order;
275+
return this;
276+
}
277+
278+
/**
279+
* Build the advisor.
280+
* @return the advisor
281+
*/
195282
public VectorStoreChatMemoryAdvisor build() {
196-
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
197-
this.protectFromBlocking, this.systemPromptTemplate, this.order);
283+
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.topK, this.conversationId,
284+
this.order, this.scheduler, this.vectorStore);
198285
}
199286

200287
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMemoryAdvisorReproIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void messageChatMemoryAdvisor_withPromptMessages_throwsException() {
3939
ChatMemory chatMemory = MessageWindowChatMemory.builder()
4040
.chatMemoryRepository(new InMemoryChatMemoryRepository())
4141
.build();
42-
MessageChatMemoryAdvisor advisor = new MessageChatMemoryAdvisor(chatMemory);
42+
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
4343

4444
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
4545

0 commit comments

Comments
 (0)