Skip to content

Commit 2cbfb22

Browse files
committed
Polish ChatClientAdvisorTests
1 parent 846165d commit 2cbfb22

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,16 @@
3737
import org.springframework.ai.chat.model.ChatModel;
3838
import org.springframework.ai.chat.model.ChatResponse;
3939
import org.springframework.ai.chat.model.Generation;
40-
import org.springframework.ai.chat.model.MessageAggregator;
4140
import org.springframework.ai.chat.prompt.Prompt;
4241

4342
import static org.assertj.core.api.Assertions.assertThat;
4443
import static org.mockito.BDDMockito.given;
4544

4645
/**
46+
* Tests for the ChatClient with a focus on verifying the handling of conversation memory
47+
* and the integration of PromptChatMemoryAdvisor to ensure accurate responses based on
48+
* previous interactions.
49+
*
4750
* @author Christian Tzolov
4851
* @author Alexandros Pappas
4952
*/
@@ -63,32 +66,33 @@ private String join(Flux<String> fluxContent) {
6366
@Test
6467
public void promptChatMemory() {
6568

66-
var builder = ChatResponseMetadata.builder()
67-
.id("124")
68-
.usage(new MessageAggregator.DefaultUsage(1, 2, 3))
69-
.model("gpt4o")
70-
.keyValue("created", 0L)
71-
.keyValue("system-fingerprint", "john doe");
72-
ChatResponseMetadata chatResponseMetadata = builder.build();
69+
// Create a ChatResponseMetadata instance with default values
70+
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder().build();
7371

72+
// Mock the chatModel to return predefined ChatResponse objects when called
7473
given(this.chatModel.call(this.promptCaptor.capture()))
7574
.willReturn(
7675
new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))), chatResponseMetadata))
7776
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))),
7877
chatResponseMetadata));
7978

79+
// Initialize an in-memory chat memory to store conversation history
8080
ChatMemory chatMemory = new InMemoryChatMemory();
8181

82+
// Build a ChatClient with default system text and a memory advisor
8283
var chatClient = ChatClient.builder(this.chatModel)
8384
.defaultSystem("Default system text.")
8485
.defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory))
8586
.build();
8687

88+
// Simulate a user prompt and verify the response
8789
ChatResponse chatResponse = chatClient.prompt().user("my name is John").call().chatResponse();
8890

91+
// Assert that the response content matches the expected output
8992
String content = chatResponse.getResult().getOutput().getText();
9093
assertThat(content).isEqualTo("Hello John");
9194

95+
// Capture and verify the system message instructions
9296
Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
9397
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("""
9498
Default system text.
@@ -101,13 +105,17 @@ public void promptChatMemory() {
101105
""");
102106
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
103107

108+
// Capture and verify the user message instructions
104109
Message userMessage = this.promptCaptor.getValue().getInstructions().get(1);
105110
assertThat(userMessage.getText()).isEqualToIgnoringWhitespace("my name is John");
106111

112+
// Simulate another user prompt and verify the response
107113
content = chatClient.prompt().user("What is my name?").call().content();
108114

115+
// Assert that the response content matches the expected output
109116
assertThat(content).isEqualTo("Your name is John");
110117

118+
// Capture and verify the updated system message instructions
111119
systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
112120
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("""
113121
Default system text.
@@ -122,13 +130,15 @@ public void promptChatMemory() {
122130
""");
123131
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
124132

133+
// Capture and verify the updated user message instructions
125134
userMessage = this.promptCaptor.getValue().getInstructions().get(1);
126135
assertThat(userMessage.getText()).isEqualToIgnoringWhitespace("What is my name?");
127136
}
128137

129138
@Test
130139
public void streamingPromptChatMemory() {
131140

141+
// Mock the chatModel to stream predefined ChatResponse objects
132142
given(this.chatModel.stream(this.promptCaptor.capture())).willReturn(Flux.generate(
133143
() -> new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John")))), (state, sink) -> {
134144
sink.next(state);
@@ -143,17 +153,22 @@ public void streamingPromptChatMemory() {
143153
return state;
144154
}));
145155

156+
// Initialize an in-memory chat memory to store conversation history
146157
ChatMemory chatMemory = new InMemoryChatMemory();
147158

159+
// Build a ChatClient with default system text and a memory advisor
148160
var chatClient = ChatClient.builder(this.chatModel)
149161
.defaultSystem("Default system text.")
150162
.defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory))
151163
.build();
152164

165+
// Simulate a streaming user prompt and verify the response
153166
var content = join(chatClient.prompt().user("my name is John").stream().content());
154167

168+
// Assert that the streamed content matches the expected output
155169
assertThat(content).isEqualTo("Hello John");
156170

171+
// Capture and verify the system message instructions
157172
Message systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
158173
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("""
159174
Default system text.
@@ -166,13 +181,17 @@ public void streamingPromptChatMemory() {
166181
""");
167182
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
168183

184+
// Capture and verify the user message instructions
169185
Message userMessage = this.promptCaptor.getValue().getInstructions().get(1);
170186
assertThat(userMessage.getText()).isEqualToIgnoringWhitespace("my name is John");
171187

188+
// Simulate another streaming user prompt and verify the response
172189
content = join(chatClient.prompt().user("What is my name?").stream().content());
173190

191+
// Assert that the streamed content matches the expected output
174192
assertThat(content).isEqualTo("Your name is John");
175193

194+
// Capture and verify the updated system message instructions
176195
systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
177196
assertThat(systemMessage.getText()).isEqualToIgnoringWhitespace("""
178197
Default system text.
@@ -187,6 +206,7 @@ public void streamingPromptChatMemory() {
187206
""");
188207
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
189208

209+
// Capture and verify the updated user message instructions
190210
userMessage = this.promptCaptor.getValue().getInstructions().get(1);
191211
assertThat(userMessage.getText()).isEqualToIgnoringWhitespace("What is my name?");
192212
}

0 commit comments

Comments
 (0)