Skip to content

Commit e63dc6a

Browse files
committed
Enhance ChatClientRequestSpec with sealed Prompt input
- When ChatClientRequestSpec#prompt(Prompt) is used, unseal the prompt instance. Convert the last message instance (if user message) into spec#user and spec#media and add the remaining messages (excluding the last) to the spec#messages. Add the prompt#options to the spec#options. - Improve DefaultChatClient to handle UserMessage media and content separately. - Update AbstractToolCallSupport to use new hasToolCalls() method. - Add hasToolCalls() method to AssistantMessage. - Enhance ChatClientTest with additional test cases for media handling. - Disable Groq and Nvidia integration tests due to rate limiting and credit requirements.
1 parent 9822eab commit e63dc6a

File tree

6 files changed

+59
-27
lines changed

6 files changed

+59
-27
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666

6767
@SpringBootTest(classes = GroqWithOpenAiChatModelIT.Config.class)
6868
@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+")
69+
@Disabled("Due to rate limiting it is hard to run it in one go")
6970
class GroqWithOpenAiChatModelIT {
7071

7172
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelIT.class);

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Map;
2424
import java.util.stream.Collectors;
2525

26+
import org.junit.jupiter.api.Disabled;
2627
import org.junit.jupiter.api.Test;
2728
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2829
import org.slf4j.Logger;
@@ -61,6 +62,7 @@
6162
*/
6263
@SpringBootTest(classes = NvidiaWithOpenAiChatModelIT.Config.class)
6364
@EnabledIfEnvironmentVariable(named = "NVIDIA_API_KEY", matches = ".+")
65+
@Disabled("Requires NVIDIA credits")
6466
class NvidiaWithOpenAiChatModelIT {
6567

6668
private static final Logger logger = LoggerFactory.getLogger(NvidiaWithOpenAiChatModelIT.class);

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
4343
import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention;
4444
import org.springframework.ai.chat.messages.Message;
45+
import org.springframework.ai.chat.messages.MessageType;
4546
import org.springframework.ai.chat.messages.SystemMessage;
4647
import org.springframework.ai.chat.messages.UserMessage;
4748
import org.springframework.ai.chat.model.ChatModel;
@@ -104,11 +105,33 @@ public ChatClientRequestSpec prompt(String content) {
104105
public ChatClientRequestSpec prompt(Prompt prompt) {
105106

106107
DefaultChatClientRequestSpec spec = new DefaultChatClientRequestSpec(this.defaultChatClientRequest);
107-
spec.messages(prompt.getInstructions());
108+
109+
// Options
108110
if (prompt.getOptions() != null) {
109111
spec.options(prompt.getOptions());
110112
}
111113

114+
// Messages
115+
List<Message> messages = prompt.getInstructions();
116+
117+
if (!CollectionUtils.isEmpty(messages)) {
118+
var lastMessage = messages.get(messages.size() - 1);
119+
if (lastMessage.getMessageType() == MessageType.USER) {
120+
// Unzip the last message
121+
var userMessage = (UserMessage) lastMessage;
122+
if (StringUtils.hasText(userMessage.getContent())) {
123+
spec.user(lastMessage.getContent());
124+
}
125+
var media = userMessage.getMedia();
126+
if (!CollectionUtils.isEmpty(media)) {
127+
spec.user(u -> u.media(media.toArray(new Media[media.size()])));
128+
}
129+
messages = messages.subList(0, messages.size() - 1);
130+
}
131+
}
132+
133+
spec.messages(messages);
134+
112135
return spec;
113136
}
114137

spring-ai-core/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Objects;
2121

2222
import org.springframework.util.Assert;
23+
import org.springframework.util.CollectionUtils;
2324

2425
/**
2526
* Lets the generative know the content was generated as a response to the user. This role
@@ -52,15 +53,14 @@ public AssistantMessage(String content, Map<String, Object> properties, List<Too
5253
this.toolCalls = toolCalls;
5354
}
5455

55-
@Override
56-
public String getContent() {
57-
return this.textContent;
58-
}
59-
6056
public List<ToolCall> getToolCalls() {
6157
return this.toolCalls;
6258
}
6359

60+
public boolean hasToolCalls() {
61+
return !CollectionUtils.isEmpty(this.toolCalls);
62+
}
63+
6464
@Override
6565
public boolean equals(Object o) {
6666
if (this == o)

spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ protected boolean isToolCall(ChatResponse chatResponse, Set<String> toolCallFini
230230
protected boolean isToolCall(Generation generation, Set<String> toolCallFinishReasons) {
231231
var finishReason = (generation.getMetadata().getFinishReason() != null)
232232
? generation.getMetadata().getFinishReason() : "";
233-
return !CollectionUtils.isEmpty(generation.getOutput().getToolCalls()) && toolCallFinishReasons.stream()
233+
return generation.getOutput().hasToolCalls() && toolCallFinishReasons.stream()
234234
.map(s -> s.toLowerCase())
235235
.toList()
236236
.contains(finishReason.toLowerCase());

spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.springframework.ai.chat.model.ChatResponse;
3939
import org.springframework.ai.chat.model.Generation;
4040
import org.springframework.ai.chat.prompt.Prompt;
41+
import org.springframework.ai.model.Media;
4142
import org.springframework.ai.model.function.FunctionCallingOptions;
4243
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder;
4344
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
@@ -67,14 +68,15 @@ private String join(Flux<String> fluxContent) {
6768
@Test
6869
public void defaultSystemText() {
6970

70-
when(chatModel.call(promptCaptor.capture())).thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
71+
when(chatModel.call(promptCaptor.capture()))
72+
.thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
7173

72-
when(chatModel.stream(promptCaptor.capture()))
73-
.thenReturn(Flux.generate(() -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> {
74-
sink.next(state);
75-
sink.complete();
76-
return state;
77-
}));
74+
when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate(
75+
() -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> {
76+
sink.next(state);
77+
sink.complete();
78+
return state;
79+
}));
7880

7981
var chatClient = ChatClient.builder(chatModel).defaultSystem("Default system text").build();
8082

@@ -114,14 +116,15 @@ public void defaultSystemText() {
114116
@Test
115117
public void defaultSystemTextLambda() {
116118

117-
when(chatModel.call(promptCaptor.capture())).thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
119+
when(chatModel.call(promptCaptor.capture()))
120+
.thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
118121

119-
when(chatModel.stream(promptCaptor.capture()))
120-
.thenReturn(Flux.generate(() -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> {
121-
sink.next(state);
122-
sink.complete();
123-
return state;
124-
}));
122+
when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate(
123+
() -> new ChatResponse(List.of(new Generation(new AssistantMessage("response")))), (state, sink) -> {
124+
sink.next(state);
125+
sink.complete();
126+
return state;
127+
}));
125128

126129
var chatClient = ChatClient.builder(chatModel)
127130
.defaultSystem(s -> s.text("Default system text {param1}, {param2}")
@@ -438,10 +441,9 @@ public void defaultUserText() {
438441
@Test
439442
public void simpleUserPromptAsString() {
440443
when(chatModel.call(promptCaptor.capture()))
441-
.thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
444+
.thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
442445

443-
assertThat(ChatClient.builder(chatModel).build().prompt("User prompt").call().content())
444-
.isEqualTo("response");
446+
assertThat(ChatClient.builder(chatModel).build().prompt("User prompt").call().content()).isEqualTo("response");
445447

446448
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
447449
assertThat(userMessage.getContent()).isEqualTo("User prompt");
@@ -466,13 +468,17 @@ public void simpleUserPromptObject() throws MalformedURLException {
466468
when(chatModel.call(promptCaptor.capture()))
467469
.thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
468470

469-
UserMessage message = new UserMessage("User prompt");
471+
var media = new Media(MimeTypeUtils.IMAGE_JPEG, new DefaultResourceLoader().getResource("classpath:/bikes.json"));
472+
473+
UserMessage message = new UserMessage("User prompt", List.of(media));
470474
Prompt prompt = new Prompt(message);
471475
assertThat(ChatClient.builder(chatModel).build().prompt(prompt).call().content()).isEqualTo("response");
472476

477+
assertThat(promptCaptor.getValue().getInstructions()).hasSize(1);
473478
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
474-
assertThat(userMessage.getContent()).isEqualTo("User prompt");
475479
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
480+
assertThat(userMessage.getContent()).isEqualTo("User prompt");
481+
assertThat(((UserMessage) userMessage).getMedia()).hasSize(1);
476482
}
477483

478484
@Test
@@ -527,7 +533,7 @@ public void complexCall() throws MalformedURLException {
527533
assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_PNG);
528534
assertThat(userMessage.getMedia().iterator().next().getData())
529535
.isEqualTo("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png");
530-
536+
531537
FunctionCallingOptions runtieOptions = (FunctionCallingOptions) promptCaptor.getValue().getOptions();
532538

533539
assertThat(runtieOptions.getFunctions()).containsExactly("function1");

0 commit comments

Comments
 (0)