37
37
import org .springframework .ai .chat .model .ChatModel ;
38
38
import org .springframework .ai .chat .model .ChatResponse ;
39
39
import org .springframework .ai .chat .model .Generation ;
40
- import org .springframework .ai .chat .model .MessageAggregator ;
41
40
import org .springframework .ai .chat .prompt .Prompt ;
42
41
43
42
import static org .assertj .core .api .Assertions .assertThat ;
44
43
import static org .mockito .BDDMockito .given ;
45
44
46
45
/**
46
+ * Tests for the ChatClient with a focus on verifying the handling of conversation memory
47
+ * and the integration of PromptChatMemoryAdvisor to ensure accurate responses based on
48
+ * previous interactions.
49
+ *
47
50
* @author Christian Tzolov
48
51
* @author Alexandros Pappas
49
52
*/
@@ -63,32 +66,33 @@ private String join(Flux<String> fluxContent) {
63
66
@ Test
64
67
public void promptChatMemory () {
65
68
66
- var builder = ChatResponseMetadata .builder ()
67
- .id ("124" )
68
- .usage (new MessageAggregator .DefaultUsage (1 , 2 , 3 ))
69
- .model ("gpt4o" )
70
- .keyValue ("created" , 0L )
71
- .keyValue ("system-fingerprint" , "john doe" );
72
- ChatResponseMetadata chatResponseMetadata = builder .build ();
69
+ // Create a ChatResponseMetadata instance with default values
70
+ ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata .builder ().build ();
73
71
72
+ // Mock the chatModel to return predefined ChatResponse objects when called
74
73
given (this .chatModel .call (this .promptCaptor .capture ()))
75
74
.willReturn (
76
75
new ChatResponse (List .of (new Generation (new AssistantMessage ("Hello John" ))), chatResponseMetadata ))
77
76
.willReturn (new ChatResponse (List .of (new Generation (new AssistantMessage ("Your name is John" ))),
78
77
chatResponseMetadata ));
79
78
79
+ // Initialize an in-memory chat memory to store conversation history
80
80
ChatMemory chatMemory = new InMemoryChatMemory ();
81
81
82
+ // Build a ChatClient with default system text and a memory advisor
82
83
var chatClient = ChatClient .builder (this .chatModel )
83
84
.defaultSystem ("Default system text." )
84
85
.defaultAdvisors (new PromptChatMemoryAdvisor (chatMemory ))
85
86
.build ();
86
87
88
+ // Simulate a user prompt and verify the response
87
89
ChatResponse chatResponse = chatClient .prompt ().user ("my name is John" ).call ().chatResponse ();
88
90
91
+ // Assert that the response content matches the expected output
89
92
String content = chatResponse .getResult ().getOutput ().getText ();
90
93
assertThat (content ).isEqualTo ("Hello John" );
91
94
95
+ // Capture and verify the system message instructions
92
96
Message systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
93
97
assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
94
98
Default system text.
@@ -101,13 +105,17 @@ public void promptChatMemory() {
101
105
""" );
102
106
assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
103
107
108
+ // Capture and verify the user message instructions
104
109
Message userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
105
110
assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("my name is John" );
106
111
112
+ // Simulate another user prompt and verify the response
107
113
content = chatClient .prompt ().user ("What is my name?" ).call ().content ();
108
114
115
+ // Assert that the response content matches the expected output
109
116
assertThat (content ).isEqualTo ("Your name is John" );
110
117
118
+ // Capture and verify the updated system message instructions
111
119
systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
112
120
assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
113
121
Default system text.
@@ -122,13 +130,15 @@ public void promptChatMemory() {
122
130
""" );
123
131
assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
124
132
133
+ // Capture and verify the updated user message instructions
125
134
userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
126
135
assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("What is my name?" );
127
136
}
128
137
129
138
@ Test
130
139
public void streamingPromptChatMemory () {
131
140
141
+ // Mock the chatModel to stream predefined ChatResponse objects
132
142
given (this .chatModel .stream (this .promptCaptor .capture ())).willReturn (Flux .generate (
133
143
() -> new ChatResponse (List .of (new Generation (new AssistantMessage ("Hello John" )))), (state , sink ) -> {
134
144
sink .next (state );
@@ -143,17 +153,22 @@ public void streamingPromptChatMemory() {
143
153
return state ;
144
154
}));
145
155
156
+ // Initialize an in-memory chat memory to store conversation history
146
157
ChatMemory chatMemory = new InMemoryChatMemory ();
147
158
159
+ // Build a ChatClient with default system text and a memory advisor
148
160
var chatClient = ChatClient .builder (this .chatModel )
149
161
.defaultSystem ("Default system text." )
150
162
.defaultAdvisors (new PromptChatMemoryAdvisor (chatMemory ))
151
163
.build ();
152
164
165
+ // Simulate a streaming user prompt and verify the response
153
166
var content = join (chatClient .prompt ().user ("my name is John" ).stream ().content ());
154
167
168
+ // Assert that the streamed content matches the expected output
155
169
assertThat (content ).isEqualTo ("Hello John" );
156
170
171
+ // Capture and verify the system message instructions
157
172
Message systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
158
173
assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
159
174
Default system text.
@@ -166,13 +181,17 @@ public void streamingPromptChatMemory() {
166
181
""" );
167
182
assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
168
183
184
+ // Capture and verify the user message instructions
169
185
Message userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
170
186
assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("my name is John" );
171
187
188
+ // Simulate another streaming user prompt and verify the response
172
189
content = join (chatClient .prompt ().user ("What is my name?" ).stream ().content ());
173
190
191
+ // Assert that the streamed content matches the expected output
174
192
assertThat (content ).isEqualTo ("Your name is John" );
175
193
194
+ // Capture and verify the updated system message instructions
176
195
systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
177
196
assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
178
197
Default system text.
@@ -187,6 +206,7 @@ public void streamingPromptChatMemory() {
187
206
""" );
188
207
assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
189
208
209
+ // Capture and verify the updated user message instructions
190
210
userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
191
211
assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("What is my name?" );
192
212
}
0 commit comments