Skip to content

Commit 4209d5a

Browse files
committed
Make consistent sync/stream AssistantMessage properties for OpenAI and Mistral AI
1 parent 7f1570d commit 4209d5a

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.ai.mistralai;
1717

18+
import java.util.HashMap;
1819
import java.util.HashSet;
1920
import java.util.List;
2021
import java.util.Map;
@@ -113,15 +114,28 @@ public ChatResponse call(Prompt prompt) {
113114

114115
List<Generation> generations = chatCompletion.choices()
115116
.stream()
116-
.map(choice -> new Generation(choice.message().content(),
117-
Map.of("role", choice.message().role().name()))
117+
.map(choice -> new Generation(choice.message().content(), toMap(chatCompletion.id(), choice))
118118
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
119119
.toList();
120120

121121
return new ChatResponse(generations);
122122
});
123123
}
124124

125+
private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
126+
Map<String, Object> map = new HashMap<>();
127+
128+
var message = choice.message();
129+
if (message.role() != null) {
130+
map.put("role", message.role().name());
131+
}
132+
if (choice.finishReason() != null) {
133+
map.put("finishReason", choice.finishReason().name());
134+
}
135+
map.put("id", id);
136+
return map;
137+
}
138+
125139
@Override
126140
public Flux<ChatResponse> stream(Prompt prompt) {
127141
var request = createRequest(prompt, true);

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ public ChatResponse call(Prompt prompt) {
147147
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
148148

149149
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
150-
return new Generation(choice.message().content(), toMap(choice.message()))
150+
return new Generation(choice.message().content(), toMap(chatCompletion.id(), choice))
151151
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
152152
}).toList();
153153

@@ -156,6 +156,20 @@ public ChatResponse call(Prompt prompt) {
156156
});
157157
}
158158

159+
private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
160+
Map<String, Object> map = new HashMap<>();
161+
162+
var message = choice.message();
163+
if (message.role() != null) {
164+
map.put("role", message.role().name());
165+
}
166+
if (choice.finishReason() != null) {
167+
map.put("finishReason", choice.finishReason().name());
168+
}
169+
map.put("id", id);
170+
return map;
171+
}
172+
159173
@Override
160174
public Flux<ChatResponse> stream(Prompt prompt) {
161175

@@ -280,15 +294,6 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
280294
}).toList();
281295
}
282296

283-
private Map<String, Object> toMap(ChatCompletionMessage message) {
284-
Map<String, Object> map = new HashMap<>();
285-
286-
if (message.role() != null) {
287-
map.put("role", message.role().name());
288-
}
289-
return map;
290-
}
291-
292297
@Override
293298
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
294299
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {

0 commit comments

Comments
 (0)