20
20
import java .util .HashMap ;
21
21
import java .util .List ;
22
22
import java .util .Map ;
23
- import java .util .stream .Collectors ;
24
23
25
- import reactor .core .publisher .Flux ;
24
+ import org .slf4j .Logger ;
25
+ import org .slf4j .LoggerFactory ;
26
+ import reactor .core .scheduler .Scheduler ;
27
+ import reactor .core .scheduler .Schedulers ;
26
28
27
- import org .springframework .ai .chat .client .advisor .AbstractChatMemoryAdvisor ;
28
29
import org .springframework .ai .chat .client .ChatClientRequest ;
29
30
import org .springframework .ai .chat .client .ChatClientResponse ;
30
- import org .springframework .ai .chat .client .advisor .api .CallAdvisorChain ;
31
- import org .springframework .ai .chat .client .advisor .api .StreamAdvisorChain ;
31
+ import org .springframework .ai .chat .client .advisor .api .Advisor ;
32
+ import org .springframework .ai .chat .client .advisor .api .AdvisorChain ;
33
+ import org .springframework .ai .chat .client .advisor .api .BaseAdvisor ;
34
+ import org .springframework .ai .chat .client .advisor .api .BaseChatMemoryAdvisor ;
35
+ import org .springframework .ai .chat .memory .ChatMemory ;
32
36
import org .springframework .ai .chat .messages .AssistantMessage ;
33
37
import org .springframework .ai .chat .messages .Message ;
34
38
import org .springframework .ai .chat .messages .MessageType ;
35
- import org .springframework .ai .chat .messages .SystemMessage ;
36
39
import org .springframework .ai .chat .messages .UserMessage ;
37
- import org .springframework .ai .chat .model .MessageAggregator ;
38
40
import org .springframework .ai .chat .prompt .PromptTemplate ;
39
41
import org .springframework .ai .document .Document ;
40
- import org .springframework .ai .vectorstore .SearchRequest ;
41
42
import org .springframework .ai .vectorstore .VectorStore ;
42
43
43
44
/**
48
49
* @author Christian Tzolov
49
50
* @author Thomas Vitale
50
51
* @author Oganes Bozoyan
52
+ * @author Mark Pollack
51
53
* @since 1.0.0
52
54
*/
53
- public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor <VectorStore > {
55
+ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
56
+
57
+ public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size" ;
54
58
55
59
private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId" ;
56
60
57
61
private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType" ;
58
62
63
+ /**
64
+ * The default chat memory retrieve size to use when no retrieve size is provided.
65
+ */
66
+ public static final int DEFAULT_TOP_K = 20 ;
67
+
59
68
private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate ("""
60
69
{instructions}
61
70
@@ -69,71 +78,84 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
69
78
70
79
private final PromptTemplate systemPromptTemplate ;
71
80
72
- private VectorStoreChatMemoryAdvisor (VectorStore vectorStore , String defaultConversationId ,
73
- int chatHistoryWindowSize , boolean protectFromBlocking , PromptTemplate systemPromptTemplate , int order ) {
74
- super (vectorStore , defaultConversationId , chatHistoryWindowSize , protectFromBlocking , order );
81
+ protected final int defaultChatMemoryRetrieveSize ;
82
+
83
+ private final String defaultConversationId ;
84
+
85
+ private final int order ;
86
+
87
+ private final Scheduler scheduler ;
88
+
89
+ private VectorStore vectorStore ;
90
+
91
+ public VectorStoreChatMemoryAdvisor (PromptTemplate systemPromptTemplate , int defaultChatMemoryRetrieveSize ,
92
+ String defaultConversationId , int order , Scheduler scheduler , VectorStore vectorStore ) {
75
93
this .systemPromptTemplate = systemPromptTemplate ;
94
+ this .defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize ;
95
+ this .defaultConversationId = defaultConversationId ;
96
+ this .order = order ;
97
+ this .scheduler = scheduler ;
98
+ this .vectorStore = vectorStore ;
76
99
}
77
100
78
101
public static Builder builder (VectorStore chatMemory ) {
79
102
return new Builder (chatMemory );
80
103
}
81
104
82
105
@ Override
83
- public ChatClientResponse adviseCall (ChatClientRequest chatClientRequest , CallAdvisorChain callAdvisorChain ) {
84
- chatClientRequest = this .before (chatClientRequest );
85
-
86
- ChatClientResponse chatClientResponse = callAdvisorChain .nextCall (chatClientRequest );
87
-
88
- this .after (chatClientResponse );
89
-
90
- return chatClientResponse ;
106
+ public int getOrder () {
107
+ return order ;
91
108
}
92
109
93
110
@ Override
94
- public Flux <ChatClientResponse > adviseStream (ChatClientRequest chatClientRequest ,
95
- StreamAdvisorChain streamAdvisorChain ) {
96
- Flux <ChatClientResponse > chatClientResponses = this .doNextWithProtectFromBlockingBefore (chatClientRequest ,
97
- streamAdvisorChain , this ::before );
98
-
99
- return new MessageAggregator ().aggregateChatClientResponse (chatClientResponses , this ::after );
111
+ public Scheduler getScheduler () {
112
+ return this .scheduler ;
100
113
}
101
114
102
- private ChatClientRequest before (ChatClientRequest chatClientRequest ) {
103
- String conversationId = this .doGetConversationId (chatClientRequest .context ());
104
- int chatMemoryRetrieveSize = this .doGetChatMemoryRetrieveSize (chatClientRequest .context ());
105
-
106
- // 1. Retrieve the chat memory for the current conversation.
107
- var searchRequest = SearchRequest .builder ()
108
- .query (chatClientRequest .prompt ().getUserMessage ().getText ())
109
- .topK (chatMemoryRetrieveSize )
110
- .filterExpression (DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'" )
115
+ @ Override
116
+ public ChatClientRequest before (ChatClientRequest request , AdvisorChain advisorChain ) {
117
+ String conversationId = getConversationId (request .context ());
118
+ String query = request .prompt ().getUserMessage () != null ? request .prompt ().getUserMessage ().getText () : "" ;
119
+ int topK = getChatMemoryTopK (request .context ());
120
+ String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'" ;
121
+ var searchRequest = org .springframework .ai .vectorstore .SearchRequest .builder ()
122
+ .query (query )
123
+ .topK (topK )
124
+ .filterExpression (filter )
111
125
.build ();
126
+ java .util .List <org .springframework .ai .document .Document > documents = this .vectorStore
127
+ .similaritySearch (searchRequest );
112
128
113
- List <Document > documents = this .getChatMemoryStore ().similaritySearch (searchRequest );
114
-
115
- // 2. Processed memory messages as a string.
116
129
String longTermMemory = documents == null ? ""
117
- : documents .stream ().map (Document ::getText ).collect (Collectors .joining (System .lineSeparator ()));
130
+ : documents .stream ()
131
+ .map (org .springframework .ai .document .Document ::getText )
132
+ .collect (java .util .stream .Collectors .joining (System .lineSeparator ()));
118
133
119
- // 2. Augment the system message.
120
- SystemMessage systemMessage = chatClientRequest .prompt ().getSystemMessage ();
134
+ org .springframework .ai .chat .messages .SystemMessage systemMessage = request .prompt ().getSystemMessage ();
121
135
String augmentedSystemText = this .systemPromptTemplate
122
- .render (Map .of ("instructions" , systemMessage .getText (), "long_term_memory" , longTermMemory ));
136
+ .render (java . util . Map .of ("instructions" , systemMessage .getText (), "long_term_memory" , longTermMemory ));
123
137
124
- // 3. Create a new request with the augmented system message.
125
- ChatClientRequest processedChatClientRequest = chatClientRequest .mutate ()
126
- .prompt (chatClientRequest .prompt ().augmentSystemMessage (augmentedSystemText ))
138
+ ChatClientRequest processedChatClientRequest = request .mutate ()
139
+ .prompt (request .prompt ().augmentSystemMessage (augmentedSystemText ))
127
140
.build ();
128
141
129
- // 4. Add the new user message to the conversation memory.
130
- UserMessage userMessage = processedChatClientRequest .prompt ().getUserMessage ();
131
- this .getChatMemoryStore ().write (toDocuments (List .of (userMessage ), conversationId ));
142
+ org .springframework .ai .chat .messages .UserMessage userMessage = processedChatClientRequest .prompt ()
143
+ .getUserMessage ();
144
+ if (userMessage != null ) {
145
+ this .vectorStore .write (toDocuments (java .util .List .of (userMessage ), conversationId ));
146
+ }
132
147
133
148
return processedChatClientRequest ;
134
149
}
135
150
136
- private void after (ChatClientResponse chatClientResponse ) {
151
+ private int getChatMemoryTopK (Map <String , Object > context ) {
152
+ return context .containsKey (CHAT_MEMORY_RETRIEVE_SIZE_KEY )
153
+ ? Integer .parseInt (context .get (CHAT_MEMORY_RETRIEVE_SIZE_KEY ).toString ())
154
+ : this .defaultChatMemoryRetrieveSize ;
155
+ }
156
+
157
+ @ Override
158
+ public ChatClientResponse after (ChatClientResponse chatClientResponse , AdvisorChain advisorChain ) {
137
159
List <Message > assistantMessages = new ArrayList <>();
138
160
if (chatClientResponse .chatResponse () != null ) {
139
161
assistantMessages = chatClientResponse .chatResponse ()
@@ -142,8 +164,8 @@ private void after(ChatClientResponse chatClientResponse) {
142
164
.map (g -> (Message ) g .getOutput ())
143
165
.toList ();
144
166
}
145
- this .getChatMemoryStore ()
146
- . write ( toDocuments ( assistantMessages , this . doGetConversationId ( chatClientResponse . context ()))) ;
167
+ this .vectorStore . write ( toDocuments ( assistantMessages , this . getConversationId ( chatClientResponse . context ())));
168
+ return chatClientResponse ;
147
169
}
148
170
149
171
private List <Document > toDocuments (List <Message > messages , String conversationId ) {
@@ -173,28 +195,93 @@ else if (message instanceof AssistantMessage assistantMessage) {
173
195
return docs ;
174
196
}
175
197
176
- public static class Builder extends AbstractChatMemoryAdvisor .AbstractBuilder <VectorStore > {
198
+ /**
199
+ * Builder for VectorStoreChatMemoryAdvisor.
200
+ */
201
+ public static class Builder {
177
202
178
203
private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE ;
179
204
180
- protected Builder (VectorStore chatMemory ) {
181
- super (chatMemory );
182
- }
205
+ private Integer topK = DEFAULT_TOP_K ;
183
206
184
- public Builder systemTextAdvise (String systemTextAdvise ) {
185
- this .systemPromptTemplate = new PromptTemplate (systemTextAdvise );
186
- return this ;
207
+ private String conversationId = ChatMemory .DEFAULT_CONVERSATION_ID ;
208
+
209
+ private Scheduler scheduler ;
210
+
211
+ private int order = Advisor .DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER ;
212
+
213
+ private VectorStore vectorStore ;
214
+
215
+ /**
216
+ * Creates a new builder instance.
217
+ * @param vectorStore the vector store to use
218
+ */
219
+ protected Builder (VectorStore vectorStore ) {
220
+ this .vectorStore = vectorStore ;
187
221
}
188
222
223
+ /**
224
+ * Set the system prompt template.
225
+ * @param systemPromptTemplate the system prompt template
226
+ * @return this builder
227
+ */
189
228
public Builder systemPromptTemplate (PromptTemplate systemPromptTemplate ) {
190
229
this .systemPromptTemplate = systemPromptTemplate ;
191
230
return this ;
192
231
}
193
232
194
- @ Override
233
+ /**
234
+ * Set the chat memory retrieve size.
235
+ * @param topK the chat memory retrieve size
236
+ * @return this builder
237
+ */
238
+ public Builder topK (int topK ) {
239
+ this .topK = topK ;
240
+ return this ;
241
+ }
242
+
243
+ /**
244
+ * Set the conversation id.
245
+ * @param conversationId the conversation id
246
+ * @return the builder
247
+ */
248
+ public Builder conversationId (String conversationId ) {
249
+ this .conversationId = conversationId ;
250
+ return this ;
251
+ }
252
+
253
+ /**
254
+ * Set whether to protect from blocking.
255
+ * @param protectFromBlocking whether to protect from blocking
256
+ * @return the builder
257
+ */
258
+ public Builder protectFromBlocking (boolean protectFromBlocking ) {
259
+ this .scheduler = protectFromBlocking ? BaseAdvisor .DEFAULT_SCHEDULER : Schedulers .immediate ();
260
+ return this ;
261
+ }
262
+
263
+ public Builder scheduler (Scheduler scheduler ) {
264
+ this .scheduler = scheduler ;
265
+ return this ;
266
+ }
267
+
268
+ /**
269
+ * Set the order.
270
+ * @param order the order
271
+ * @return the builder
272
+ */
273
+ public Builder order (int order ) {
274
+ this .order = order ;
275
+ return this ;
276
+ }
277
+
278
+ /**
279
+ * Build the advisor.
280
+ * @return the advisor
281
+ */
195
282
public VectorStoreChatMemoryAdvisor build () {
196
- return new VectorStoreChatMemoryAdvisor (this .chatMemory , this .conversationId , this .chatMemoryRetrieveSize ,
197
- this .protectFromBlocking , this .systemPromptTemplate , this .order );
283
+ return new VectorStoreChatMemoryAdvisor (this .systemPromptTemplate , this .topK , this .conversationId ,
284
+ this .order , this .scheduler , this .vectorStore );
198
285
}
199
286
200
287
}
0 commit comments