Skip to content

Commit 478f180

Browse files
chemicL authored and tzolov committed
Add OpenAiChatModel stream observability
Integrated Micrometer's Observation into the OpenAiChatModel#stream reactive chain. Included changes: - Added ability to aggregate streaming responses for use in Observation metadata. - Improved error handling and logging for chat response processing. - Updated unit tests to include new observation logic and subscribe to Flux responses. - Refined validation of observations in both normal and streaming chat operations. - Disabled retry for streaming which used RetryTemplate - should use .retryWhen operator as the next step. - Added an integration test. Resolves #1190 Co-authored-by Christian Tzolov <ctzolov@vmware.com>
1 parent 86348e4 commit 478f180

File tree

7 files changed

+252
-71
lines changed

7 files changed

+252
-71
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 79 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.ai.chat.model.ChatModel;
4040
import org.springframework.ai.chat.model.ChatResponse;
4141
import org.springframework.ai.chat.model.Generation;
42+
import org.springframework.ai.chat.model.MessageAggregator;
4243
import org.springframework.ai.chat.model.StreamingChatModel;
4344
import org.springframework.ai.chat.observation.ChatModelObservationContext;
4445
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
@@ -72,7 +73,9 @@
7273
import org.springframework.util.MultiValueMap;
7374
import org.springframework.util.StringUtils;
7475

76+
import io.micrometer.observation.Observation;
7577
import io.micrometer.observation.ObservationRegistry;
78+
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
7679
import reactor.core.publisher.Flux;
7780
import reactor.core.publisher.Mono;
7881

@@ -271,64 +274,90 @@ public ChatResponse call(Prompt prompt) {
271274

272275
@Override
273276
public Flux<ChatResponse> stream(Prompt prompt) {
274-
275-
ChatCompletionRequest request = createRequest(prompt, true);
276-
277-
Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
278-
.execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
279-
280-
// For chunked responses, only the first chunk contains the choice role.
281-
// The rest of the chunks with same ID share the same role.
282-
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
283-
284-
// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
285-
// the function call handling logic.
286-
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
287-
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
288-
try {
289-
@SuppressWarnings("null")
290-
String id = chatCompletion2.id();
291-
292-
// @formatter:off
293-
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
294-
if (choice.message().role() != null) {
295-
roleMap.putIfAbsent(id, choice.message().role().name());
296-
}
297-
Map<String, Object> metadata = Map.of(
298-
"id", chatCompletion2.id(),
299-
"role", roleMap.getOrDefault(id, ""),
300-
"index", choice.index(),
301-
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
302-
return buildGeneration(choice, metadata);
277+
return Flux.deferContextual(contextView -> {
278+
ChatCompletionRequest request = createRequest(prompt, true);
279+
280+
Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
281+
getAdditionalHttpHeaders(prompt));
282+
283+
// For chunked responses, only the first chunk contains the choice role.
284+
// The rest of the chunks with same ID share the same role.
285+
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
286+
287+
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
288+
.prompt(prompt)
289+
.operationMetadata(buildOperationMetadata())
290+
.requestOptions(buildRequestOptions(request))
291+
.build();
292+
293+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
294+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
295+
this.observationRegistry);
296+
297+
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
298+
299+
// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
300+
// the function call handling logic.
301+
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
302+
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
303+
try {
304+
@SuppressWarnings("null")
305+
String id = chatCompletion2.id();
306+
307+
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {// @formatter:off
308+
309+
if (choice.message().role() != null) {
310+
roleMap.putIfAbsent(id, choice.message().role().name());
311+
}
312+
Map<String, Object> metadata = Map.of(
313+
"id", chatCompletion2.id(),
314+
"role", roleMap.getOrDefault(id, ""),
315+
"index", choice.index(),
316+
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
317+
318+
return buildGeneration(choice, metadata);
303319
}).toList();
304-
// @formatter:on
320+
// @formatter:on
305321

306-
if (chatCompletion2.usage() != null) {
307322
return new ChatResponse(generations, from(chatCompletion2, null));
308323
}
309-
else {
310-
return new ChatResponse(generations);
324+
catch (Exception e) {
325+
logger.error("Error processing chat completion", e);
326+
return new ChatResponse(List.of());
311327
}
312-
}
313-
catch (Exception e) {
314-
logger.error("Error processing chat completion", e);
315-
return new ChatResponse(List.of());
316-
}
317328

318-
}));
329+
}));
319330

320-
return chatResponse.flatMap(response -> {
331+
// @formatter:off
332+
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
333+
334+
if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
335+
OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
336+
var toolCallConversation = handleToolCalls(prompt, response);
337+
// Recursively call the stream method with the tool call message
338+
// conversation that contains the call responses.
339+
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
340+
}
341+
else {
342+
return Flux.just(response);
343+
}
344+
})
345+
.doOnError(observation::error)
346+
.doFinally(s -> {
347+
// TODO: Consider a custom ObservationContext and
348+
// include additional metadata
349+
// if (s == SignalType.CANCEL) {
350+
// observationContext.setAborted(true);
351+
// }
352+
observation.stop();
353+
})
354+
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
355+
// @formatter:on
356+
357+
return new MessageAggregator().aggregate(flux, mergedChatResponse -> {
358+
observationContext.setResponse(mergedChatResponse);
359+
});
321360

322-
if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
323-
OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
324-
var toolCallConversation = handleToolCalls(prompt, response);
325-
// Recursively call the stream method with the tool call message
326-
// conversation that contains the call responses.
327-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
328-
}
329-
else {
330-
return Flux.just(response);
331-
}
332361
});
333362
}
334363

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public void streamUserMessageSimpleContentType() {
104104

105105
when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);
106106

107-
chatModel.stream(new Prompt(List.of(new UserMessage("test message"))));
107+
chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe();
108108

109109
validateStringContent(pomptCaptor.getValue());
110110
assertThat(headersCaptor.getValue()).isEmpty();
@@ -137,8 +137,10 @@ public void streamUserMessageWithMediaType() throws MalformedURLException {
137137
when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);
138138

139139
URL mediaUrl = new URL("http://test");
140-
chatModel.stream(new Prompt(
141-
List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))));
140+
chatModel
141+
.stream(new Prompt(
142+
List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))))
143+
.subscribe();
142144

143145
validateComplexContent(pomptCaptor.getValue());
144146
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import io.micrometer.common.KeyValue;
1919
import io.micrometer.observation.tck.TestObservationRegistry;
2020
import io.micrometer.observation.tck.TestObservationRegistryAssert;
21+
import reactor.core.publisher.Flux;
22+
23+
import org.junit.jupiter.api.BeforeEach;
2124
import org.junit.jupiter.api.Test;
2225
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2326
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -37,6 +40,7 @@
3740
import org.springframework.retry.support.RetryTemplate;
3841

3942
import java.util.List;
43+
import java.util.stream.Collectors;
4044

4145
import static org.assertj.core.api.Assertions.assertThat;
4246
import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;
@@ -57,8 +61,14 @@ public class OpenAiChatModelObservationIT {
5761
@Autowired
5862
OpenAiChatModel chatModel;
5963

64+
@BeforeEach
65+
void beforeEach() {
66+
observationRegistry.clear();
67+
}
68+
6069
@Test
61-
void observationForEmbeddingOperation() {
70+
void observationForChatOperation() {
71+
6272
var options = OpenAiChatOptions.builder()
6373
.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
6474
.withFrequencyPenalty(0f)
@@ -77,6 +87,45 @@ void observationForEmbeddingOperation() {
7787
ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
7888
assertThat(responseMetadata).isNotNull();
7989

90+
validate(responseMetadata);
91+
}
92+
93+
@Test
94+
void observationForStreamingChatOperation() {
95+
var options = OpenAiChatOptions.builder()
96+
.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
97+
.withFrequencyPenalty(0f)
98+
.withMaxTokens(2048)
99+
.withPresencePenalty(0f)
100+
.withStop(List.of("this-is-the-end"))
101+
.withTemperature(0.7f)
102+
.withTopP(1f)
103+
.withStreamUsage(true)
104+
.build();
105+
106+
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
107+
108+
Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);
109+
110+
List<ChatResponse> responses = chatResponseFlux.collectList().block();
111+
assertThat(responses).isNotEmpty();
112+
assertThat(responses).hasSizeGreaterThan(10);
113+
114+
String aggregatedResponse = responses.subList(0, responses.size() - 1)
115+
.stream()
116+
.map(r -> r.getResult().getOutput().getContent())
117+
.collect(Collectors.joining());
118+
assertThat(aggregatedResponse).isNotEmpty();
119+
120+
ChatResponse lastChatResponse = responses.get(responses.size() - 1);
121+
122+
ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata();
123+
assertThat(responseMetadata).isNotNull();
124+
125+
validate(responseMetadata);
126+
}
127+
128+
private void validate(ChatResponseMetadata responseMetadata) {
80129
TestObservationRegistryAssert.assertThat(observationRegistry)
81130
.doesNotHaveAnyRemainingCurrentObservation()
82131
.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@
5353
* @author Christian Tzolov
5454
*/
5555
@SpringBootTest
56-
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
57-
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
56+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
5857
public class OpenAiPaymentTransactionIT {
5958

6059
private final static Logger logger = LoggerFactory.getLogger(OpenAiPaymentTransactionIT.class);

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Optional;
2020

2121
import org.junit.jupiter.api.BeforeEach;
22+
import org.junit.jupiter.api.Disabled;
2223
import org.junit.jupiter.api.Test;
2324
import org.junit.jupiter.api.extension.ExtendWith;
2425
import org.mockito.Mock;
@@ -163,6 +164,7 @@ public void openAiChatNonTransientError() {
163164
}
164165

165166
@Test
167+
@Disabled("Currently stream() does not implmement retry")
166168
public void openAiChatStreamTransientError() {
167169

168170
var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
@@ -184,10 +186,11 @@ public void openAiChatStreamTransientError() {
184186
}
185187

186188
@Test
189+
@Disabled("Currently stream() does not implmement retry")
187190
public void openAiChatStreamNonTransientError() {
188191
when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any()))
189192
.thenThrow(new RuntimeException("Non Transient Error"));
190-
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
193+
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).subscribe());
191194
}
192195

193196
@Test

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.slf4j.Logger;
2828
import org.slf4j.LoggerFactory;
2929
import org.springframework.ai.chat.client.ChatClient;
30-
import org.springframework.ai.chat.client.DefaultChatClient;
3130
import org.springframework.ai.openai.OpenAiTestConfiguration;
3231
import org.springframework.ai.openai.api.tool.MockWeatherService;
3332
import org.springframework.ai.openai.testutils.AbstractIT;

0 commit comments

Comments (0)