Skip to content

Commit af0303f

Browse files
berjanjonkermarkpollack
authored andcommitted
Make responseMessage in AzureOpenAiChatModel.buildGeneration null-safe
Signed-off-by: Berjan Jonker <berjanjonker@users.noreply.github.com>
1 parent 5634d89 commit af0303f

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -112,6 +112,7 @@
112112
* @author Jihoon Kim
113113
* @author Ilayaperumal Gopinathan
114114
* @author Alexandros Pappas
115+
* @author Berjan Jonker
115116
* @see ChatModel
116117
* @see com.azure.ai.openai.OpenAIClient
117118
* @since 1.0.0
@@ -462,16 +463,19 @@ private Generation buildGeneration(ChatChoice choice, Map<String, Object> metada
462463

463464
var responseMessage = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta());
464465

465-
List<AssistantMessage.ToolCall> toolCalls = responseMessage.getToolCalls() == null ? List.of()
466-
: responseMessage.getToolCalls().stream().map(toolCall -> {
467-
final var tc1 = (ChatCompletionsFunctionToolCall) toolCall;
468-
String id = tc1.getId();
469-
String name = tc1.getFunction().getName();
470-
String arguments = tc1.getFunction().getArguments();
471-
return new AssistantMessage.ToolCall(id, "function", name, arguments);
472-
}).toList();
466+
List<AssistantMessage.ToolCall> toolCalls = List.of();
467+
if (responseMessage != null && responseMessage.getToolCalls() != null) {
468+
toolCalls = responseMessage.getToolCalls().stream().map(toolCall -> {
469+
final var tc1 = (ChatCompletionsFunctionToolCall) toolCall;
470+
String id = tc1.getId();
471+
String name = tc1.getFunction().getName();
472+
String arguments = tc1.getFunction().getArguments();
473+
return new AssistantMessage.ToolCall(id, "function", name, arguments);
474+
}).toList();
475+
}
473476

474-
var assistantMessage = new AssistantMessage(responseMessage.getContent(), metadata, toolCalls);
477+
var content = responseMessage == null ? "" : responseMessage.getContent();
478+
var assistantMessage = new AssistantMessage(content, metadata, toolCalls);
475479
var generationMetadata = generateChoiceMetadata(choice);
476480

477481
return new Generation(assistantMessage, generationMetadata);

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.Objects;
25+
import java.util.concurrent.atomic.AtomicInteger;
2526
import java.util.stream.Collectors;
2627

2728
import com.azure.ai.openai.OpenAIClientBuilder;
@@ -106,6 +107,26 @@ void testMessageHistory() {
106107
assertThat(response.getResult().getOutput().getText()).containsAnyOf("Blackbeard");
107108
}
108109

110+
@Test
111+
void testStreaming() {
112+
String prompt = """
113+
Provide a list of planets in our solar system
114+
""";
115+
116+
final var counter = new AtomicInteger();
117+
String content = this.chatModel.stream(prompt)
118+
.doOnEach(listSignal -> counter.getAndIncrement())
119+
.collectList()
120+
.block()
121+
.stream()
122+
.collect(Collectors.joining());
123+
logger.info("Response: {}", content);
124+
125+
assertThat(counter.get()).isGreaterThan(8).as("More than 8 chuncks because there are 8 planets");
126+
127+
assertThat(content).contains("Earth", "Mars", "Jupiter");
128+
}
129+
109130
@Test
110131
void listOutputConverter() {
111132
DefaultConversionService conversionService = new DefaultConversionService();

0 commit comments

Comments
 (0)