Skip to content

Commit 866b262

Browse files
TarasVovk669tzolov
authored andcommitted
Add refusal field to ChatCompletionMessage and related classes
- Updated OpenAiChatModel, OpenAiApi, and OpenAiStreamFunctionCallingHelper to include the `refusal` field in metadata. - Adjusted constructors and methods to handle the new `refusal` attribute. - Modified related tests to account for the new `refusal` field. - Add the refusal field value to the Spring AI AssistantMessage metadata Resolves #1178
1 parent e2c5208 commit 866b262

File tree

4 files changed

+17
-13
lines changed

4 files changed

+17
-13
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,13 @@ public ChatResponse call(Prompt prompt) {
241241

242242
List<Generation> generations = choices.stream().map(choice -> {
243243
// @formatter:off
244-
Map<String, Object> metadata = Map.of(
245-
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
246-
"role", choice.message().role() != null ? choice.message().role().name() : "",
247-
"index", choice.index(),
248-
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
249-
// @formatter:on
244+
Map<String, Object> metadata = Map.of(
245+
"id", chatCompletion.id() != null ? chatCompletion.id() : "",
246+
"role", choice.message().role() != null ? choice.message().role().name() : "",
247+
"index", choice.index(),
248+
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
249+
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
250+
// @formatter:on
250251
return buildGeneration(choice, metadata);
251252
}).toList();
252253

@@ -313,7 +314,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
313314
"id", chatCompletion2.id(),
314315
"role", roleMap.getOrDefault(id, ""),
315316
"index", choice.index(),
316-
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
317+
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
318+
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
317319

318320
return buildGeneration(choice, metadata);
319321
}).toList();
@@ -453,7 +455,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
453455
}).toList();
454456
}
455457
return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
456-
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
458+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null));
457459
}
458460
else if (message.getMessageType() == MessageType.TOOL) {
459461
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -466,7 +468,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
466468
return toolMessage.getResponses()
467469
.stream()
468470
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
469-
tr.id(), null))
471+
tr.id(), null, null))
470472
.toList();
471473
}
472474
else {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,8 @@ public record ChatCompletionMessage(// @formatter:off
560560
@JsonProperty("role") Role role,
561561
@JsonProperty("name") String name,
562562
@JsonProperty("tool_call_id") String toolCallId,
563-
@JsonProperty("tool_calls") List<ToolCall> toolCalls) {// @formatter:on
563+
@JsonProperty("tool_calls") List<ToolCall> toolCalls,
564+
@JsonProperty("refusal") String refusal) {// @formatter:on
564565

565566
/**
566567
* Get message content as String.
@@ -582,7 +583,7 @@ public String content() {
582583
* @param role The role of the author of this message.
583584
*/
584585
public ChatCompletionMessage(Object content, Role role) {
585-
this(content, role, null, null, null);
586+
this(content, role, null, null, null, null);
586587
}
587588

588589
/**

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
9191
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
9292
String name = (current.name() != null ? current.name() : previous.name());
9393
String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId());
94+
String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
9495

9596
List<ToolCall> toolCalls = new ArrayList<>();
9697
ToolCall lastPreviousTooCall = null;
@@ -120,7 +121,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
120121
toolCalls.add(lastPreviousTooCall);
121122
}
122123
}
123-
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls);
124+
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal);
124125
}
125126

126127
private ToolCall merge(ToolCall previous, ToolCall current) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public void toolFunctionCall() {
122122

123123
// extend conversation with function response.
124124
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(),
125-
Role.TOOL, functionName, toolCall.id(), null));
125+
Role.TOOL, functionName, toolCall.id(), null, null));
126126
}
127127
}
128128

0 commit comments

Comments
 (0)