Skip to content

Commit 8939148

Browse files
ThomasVitalemarkpollack
authored andcommitted
Fixes for memory advisors after recent refactoring
* The defaultConversationId was configurable, but not used. It’s now being used correctly when a custom defaultConversationId is defined. * The memory advisors were missing the required configuration of a Schedule due to a default value missing. Now as default Scheduler is used, automatically protecting from blocking. It can be customised via “scheduler()”, replacing the old “protectFromBlocking()” method. * The new defaultTopK options in VectorStoreChatMemoryAdvisor were documented, but not implemented. That is fixed now. * The memory advisors were not null-safe. Now they are. * Improved tests to check the null-safe behaviour. * Updated the documentation accordingly. Fixes gh-3133 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent f346092 commit 8939148

File tree

9 files changed

+228
-90
lines changed

9 files changed

+228
-90
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@
2121
import java.util.List;
2222
import java.util.Map;
2323

24-
import org.slf4j.Logger;
25-
import org.slf4j.LoggerFactory;
24+
import org.springframework.util.Assert;
2625
import reactor.core.scheduler.Scheduler;
27-
import reactor.core.scheduler.Schedulers;
2826

2927
import org.springframework.ai.chat.client.ChatClientRequest;
3028
import org.springframework.ai.chat.client.ChatClientResponse;
@@ -54,16 +52,13 @@
5452
*/
5553
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
5654

57-
public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size";
55+
public static final String TOP_K = "chat_memory_vector_store_top_k";
5856

5957
private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId";
6058

6159
private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType";
6260

63-
/**
64-
* The default chat memory retrieve size to use when no retrieve size is provided.
65-
*/
66-
public static final int DEFAULT_TOP_K = 20;
61+
private static final int DEFAULT_TOP_K = 20;
6762

6863
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate("""
6964
{instructions}
@@ -78,20 +73,25 @@ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
7873

7974
private final PromptTemplate systemPromptTemplate;
8075

81-
protected final int defaultChatMemoryRetrieveSize;
76+
private final int defaultTopK;
8277

8378
private final String defaultConversationId;
8479

8580
private final int order;
8681

8782
private final Scheduler scheduler;
8883

89-
private VectorStore vectorStore;
84+
private final VectorStore vectorStore;
9085

91-
public VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultChatMemoryRetrieveSize,
86+
private VectorStoreChatMemoryAdvisor(PromptTemplate systemPromptTemplate, int defaultTopK,
9287
String defaultConversationId, int order, Scheduler scheduler, VectorStore vectorStore) {
88+
Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null");
89+
Assert.isTrue(defaultTopK > 0, "topK must be greater than 0");
90+
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
91+
Assert.notNull(scheduler, "scheduler cannot be null");
92+
Assert.notNull(vectorStore, "vectorStore cannot be null");
9393
this.systemPromptTemplate = systemPromptTemplate;
94-
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
94+
this.defaultTopK = defaultTopK;
9595
this.defaultConversationId = defaultConversationId;
9696
this.order = order;
9797
this.scheduler = scheduler;
@@ -114,7 +114,7 @@ public Scheduler getScheduler() {
114114

115115
@Override
116116
public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
117-
String conversationId = getConversationId(request.context());
117+
String conversationId = getConversationId(request.context(), this.defaultConversationId);
118118
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
119119
int topK = getChatMemoryTopK(request.context());
120120
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
@@ -149,9 +149,7 @@ public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorC
149149
}
150150

151151
private int getChatMemoryTopK(Map<String, Object> context) {
152-
return context.containsKey(CHAT_MEMORY_RETRIEVE_SIZE_KEY)
153-
? Integer.parseInt(context.get(CHAT_MEMORY_RETRIEVE_SIZE_KEY).toString())
154-
: this.defaultChatMemoryRetrieveSize;
152+
return context.containsKey(TOP_K) ? Integer.parseInt(context.get(TOP_K).toString()) : this.defaultTopK;
155153
}
156154

157155
@Override
@@ -164,7 +162,8 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
164162
.map(g -> (Message) g.getOutput())
165163
.toList();
166164
}
167-
this.vectorStore.write(toDocuments(assistantMessages, this.getConversationId(chatClientResponse.context())));
165+
this.vectorStore.write(toDocuments(assistantMessages,
166+
this.getConversationId(chatClientResponse.context(), this.defaultConversationId)));
168167
return chatClientResponse;
169168
}
170169

@@ -202,11 +201,11 @@ public static class Builder {
202201

203202
private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE;
204203

205-
private Integer topK = DEFAULT_TOP_K;
204+
private Integer defaultTopK = DEFAULT_TOP_K;
206205

207206
private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID;
208207

209-
private Scheduler scheduler;
208+
private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER;
210209

211210
private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;
212211

@@ -232,11 +231,11 @@ public Builder systemPromptTemplate(PromptTemplate systemPromptTemplate) {
232231

233232
/**
234233
* Set the chat memory retrieve size.
235-
* @param topK the chat memory retrieve size
234+
* @param defaultTopK the chat memory retrieve size
236235
* @return this builder
237236
*/
238-
public Builder topK(int topK) {
239-
this.topK = topK;
237+
public Builder defaultTopK(int defaultTopK) {
238+
this.defaultTopK = defaultTopK;
240239
return this;
241240
}
242241

@@ -250,16 +249,6 @@ public Builder conversationId(String conversationId) {
250249
return this;
251250
}
252251

253-
/**
254-
* Set whether to protect from blocking.
255-
* @param protectFromBlocking whether to protect from blocking
256-
* @return the builder
257-
*/
258-
public Builder protectFromBlocking(boolean protectFromBlocking) {
259-
this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
260-
return this;
261-
}
262-
263252
public Builder scheduler(Scheduler scheduler) {
264253
this.scheduler = scheduler;
265254
return this;
@@ -280,7 +269,7 @@ public Builder order(int order) {
280269
* @return the advisor
281270
*/
282271
public VectorStoreChatMemoryAdvisor build() {
283-
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.topK, this.conversationId,
272+
return new VectorStoreChatMemoryAdvisor(this.systemPromptTemplate, this.defaultTopK, this.conversationId,
284273
this.order, this.scheduler, this.vectorStore);
285274
}
286275

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package org.springframework.ai.chat.client.advisor.vectorstore;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.mockito.Mockito;
5+
import org.springframework.ai.vectorstore.VectorStore;
6+
7+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
8+
9+
/**
10+
* Unit tests for {@link VectorStoreChatMemoryAdvisor}.
11+
*
12+
* @author Thomas Vitale
13+
*/
14+
class VectorStoreChatMemoryAdvisorTests {
15+
16+
@Test
17+
void whenVectorStoreIsNullThenThrow() {
18+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(null).build())
19+
.isInstanceOf(IllegalArgumentException.class)
20+
.hasMessageContaining("vectorStore cannot be null");
21+
}
22+
23+
@Test
24+
void whenDefaultConversationIdIsNullThenThrow() {
25+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
26+
27+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).conversationId(null).build())
28+
.isInstanceOf(IllegalArgumentException.class)
29+
.hasMessageContaining("defaultConversationId cannot be null or empty");
30+
}
31+
32+
@Test
33+
void whenDefaultConversationIdIsEmptyThenThrow() {
34+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
35+
36+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).conversationId(null).build())
37+
.isInstanceOf(IllegalArgumentException.class)
38+
.hasMessageContaining("defaultConversationId cannot be null or empty");
39+
}
40+
41+
@Test
42+
void whenSchedulerIsNullThenThrow() {
43+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
44+
45+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).scheduler(null).build())
46+
.isInstanceOf(IllegalArgumentException.class)
47+
.hasMessageContaining("scheduler cannot be null");
48+
}
49+
50+
@Test
51+
void whenSystemPromptTemplateIsNullThenThrow() {
52+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
53+
54+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).systemPromptTemplate(null).build())
55+
.isInstanceOf(IllegalArgumentException.class)
56+
.hasMessageContaining("systemPromptTemplate cannot be null");
57+
}
58+
59+
@Test
60+
void whenDefaultTopKIsZeroThenThrow() {
61+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
62+
63+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).defaultTopK(0).build())
64+
.isInstanceOf(IllegalArgumentException.class)
65+
.hasMessageContaining("topK must be greater than 0");
66+
}
67+
68+
@Test
69+
void whenDefaultTopKIsNegativeThenThrow() {
70+
VectorStore vectorStore = Mockito.mock(VectorStore.class);
71+
72+
assertThatThrownBy(() -> VectorStoreChatMemoryAdvisor.builder(vectorStore).defaultTopK(-1).build())
73+
.isInstanceOf(IllegalArgumentException.class)
74+
.hasMessageContaining("topK must be greater than 0");
75+
}
76+
77+
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,9 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21-
import java.util.Map;
2221

23-
import org.slf4j.Logger;
24-
import org.slf4j.LoggerFactory;
22+
import org.springframework.util.Assert;
2523
import reactor.core.scheduler.Scheduler;
26-
import reactor.core.scheduler.Schedulers;
2724

2825
import org.springframework.ai.chat.client.ChatClientRequest;
2926
import org.springframework.ai.chat.client.ChatClientResponse;
@@ -40,12 +37,11 @@
4037
*
4138
* @author Christian Tzolov
4239
* @author Mark Pollack
40+
* @author Thomas Vitale
4341
* @since 1.0.0
4442
*/
4543
public class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {
4644

47-
private static final Logger logger = LoggerFactory.getLogger(MessageChatMemoryAdvisor.class);
48-
4945
private final ChatMemory chatMemory;
5046

5147
private final String defaultConversationId;
@@ -56,6 +52,9 @@ public class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {
5652

5753
private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order,
5854
Scheduler scheduler) {
55+
Assert.notNull(chatMemory, "chatMemory cannot be null");
56+
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
57+
Assert.notNull(scheduler, "scheduler cannot be null");
5958
this.chatMemory = chatMemory;
6059
this.defaultConversationId = defaultConversationId;
6160
this.order = order;
@@ -74,7 +73,7 @@ public Scheduler getScheduler() {
7473

7574
@Override
7675
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
77-
String conversationId = getConversationId(chatClientRequest.context());
76+
String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
7877

7978
// 1. Retrieve the chat memory for the current conversation.
8079
List<Message> memoryMessages = this.chatMemory.get(conversationId);
@@ -105,7 +104,8 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
105104
.map(g -> (Message) g.getOutput())
106105
.toList();
107106
}
108-
this.chatMemory.add(this.getConversationId(chatClientResponse.context()), assistantMessages);
107+
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
108+
assistantMessages);
109109
return chatClientResponse;
110110
}
111111

@@ -119,7 +119,7 @@ public static class Builder {
119119

120120
private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;
121121

122-
private Scheduler scheduler;
122+
private Scheduler scheduler = BaseAdvisor.DEFAULT_SCHEDULER;
123123

124124
private ChatMemory chatMemory;
125125

@@ -137,16 +137,6 @@ public Builder conversationId(String conversationId) {
137137
return this;
138138
}
139139

140-
/**
141-
* Set whether to protect from blocking.
142-
* @param protectFromBlocking whether to protect from blocking
143-
* @return the builder
144-
*/
145-
public Builder protectFromBlocking(boolean protectFromBlocking) {
146-
this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
147-
return this;
148-
}
149-
150140
/**
151141
* Set the order.
152142
* @param order the order

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323

2424
import org.slf4j.Logger;
2525
import org.slf4j.LoggerFactory;
26+
import org.springframework.util.Assert;
2627
import reactor.core.publisher.Flux;
2728
import reactor.core.publisher.Mono;
2829
import reactor.core.scheduler.Scheduler;
29-
import reactor.core.scheduler.Schedulers;
3030

3131
import org.springframework.ai.chat.client.ChatClientMessageAggregator;
3232
import org.springframework.ai.chat.client.ChatClientRequest;
@@ -80,6 +80,10 @@ public class PromptChatMemoryAdvisor implements BaseChatMemoryAdvisor {
8080

8181
private PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order, Scheduler scheduler,
8282
PromptTemplate systemPromptTemplate) {
83+
Assert.notNull(chatMemory, "chatMemory cannot be null");
84+
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
85+
Assert.notNull(scheduler, "scheduler cannot be null");
86+
Assert.notNull(systemPromptTemplate, "systemPromptTemplate cannot be null");
8387
this.chatMemory = chatMemory;
8488
this.defaultConversationId = defaultConversationId;
8589
this.order = order;
@@ -103,7 +107,7 @@ public Scheduler getScheduler() {
103107

104108
@Override
105109
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
106-
String conversationId = getConversationId(chatClientRequest.context());
110+
String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
107111
// 1. Retrieve the chat memory for the current conversation.
108112
List<Message> memoryMessages = this.chatMemory.get(conversationId);
109113
logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}",
@@ -151,12 +155,15 @@ else if (chatClientResponse.chatResponse() != null && chatClientResponse.chatRes
151155
}
152156

153157
if (!assistantMessages.isEmpty()) {
154-
this.chatMemory.add(this.getConversationId(chatClientResponse.context()), assistantMessages);
158+
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
159+
assistantMessages);
155160
logger.debug("[PromptChatMemoryAdvisor.after] Added ASSISTANT messages to memory for conversationId={}: {}",
156-
this.getConversationId(chatClientResponse.context()), assistantMessages);
157-
List<Message> memoryMessages = this.chatMemory.get(this.getConversationId(chatClientResponse.context()));
161+
this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
162+
assistantMessages);
163+
List<Message> memoryMessages = this.chatMemory
164+
.get(this.getConversationId(chatClientResponse.context(), this.defaultConversationId));
158165
logger.debug("[PromptChatMemoryAdvisor.after] Memory after ASSISTANT add for conversationId={}: {}",
159-
this.getConversationId(chatClientResponse.context()), memoryMessages);
166+
this.getConversationId(chatClientResponse.context(), this.defaultConversationId), memoryMessages);
160167
}
161168
return chatClientResponse;
162169
}
@@ -215,16 +222,6 @@ public Builder conversationId(String conversationId) {
215222
return this;
216223
}
217224

218-
/**
219-
* Set whether to protect from blocking.
220-
* @param protectFromBlocking whether to protect from blocking
221-
* @return the builder
222-
*/
223-
public Builder protectFromBlocking(boolean protectFromBlocking) {
224-
this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
225-
return this;
226-
}
227-
228225
public Builder scheduler(Scheduler scheduler) {
229226
this.scheduler = scheduler;
230227
return this;

0 commit comments

Comments
 (0)