@@ -39,6 +39,7 @@
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.model.MessageAggregator;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.observation.ChatModelObservationContext;
 import org.springframework.ai.chat.observation.ChatModelObservationConvention;
@@ -72,7 +73,9 @@
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 
+import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
@@ -271,64 +274,90 @@ public ChatResponse call(Prompt prompt) {
 
     @Override
     public Flux<ChatResponse> stream(Prompt prompt) {
-
-        ChatCompletionRequest request = createRequest(prompt, true);
-
-        Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
-            .execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
-
-        // For chunked responses, only the first chunk contains the choice role.
-        // The rest of the chunks with same ID share the same role.
-        ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
-
-        // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
-        // the function call handling logic.
-        Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
-            .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
-                try {
-                    @SuppressWarnings("null")
-                    String id = chatCompletion2.id();
-
-                    // @formatter:off
-                    List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
-                        if (choice.message().role() != null) {
-                            roleMap.putIfAbsent(id, choice.message().role().name());
-                        }
-                        Map<String, Object> metadata = Map.of(
-                                "id", chatCompletion2.id(),
-                                "role", roleMap.getOrDefault(id, ""),
-                                "index", choice.index(),
-                                "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
-                        return buildGeneration(choice, metadata);
+        return Flux.deferContextual(contextView -> {
+            ChatCompletionRequest request = createRequest(prompt, true);
+
+            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
+                    getAdditionalHttpHeaders(prompt));
+
+            // For chunked responses, only the first chunk contains the choice role.
+            // The rest of the chunks with same ID share the same role.
+            ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
+
+            final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+                .prompt(prompt)
+                .operationMetadata(buildOperationMetadata())
+                .requestOptions(buildRequestOptions(request))
+                .build();
+
+            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+                    this.observationRegistry);
+
+            observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
+
+            // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
+            // the function call handling logic.
+            Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
+                .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
+                    try {
+                        @SuppressWarnings("null")
+                        String id = chatCompletion2.id();
+
+                        List<Generation> generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off
+
+                            if (choice.message().role() != null) {
+                                roleMap.putIfAbsent(id, choice.message().role().name());
+                            }
+                            Map<String, Object> metadata = Map.of(
+                                    "id", chatCompletion2.id(),
+                                    "role", roleMap.getOrDefault(id, ""),
+                                    "index", choice.index(),
+                                    "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
+
+                            return buildGeneration(choice, metadata);
                     }).toList();
-                    // @formatter:on
+                        // @formatter:on
 
-                    if (chatCompletion2.usage() != null) {
                         return new ChatResponse(generations, from(chatCompletion2, null));
                     }
-                    else {
-                        return new ChatResponse(generations);
+                    catch (Exception e) {
+                        logger.error("Error processing chat completion", e);
+                        return new ChatResponse(List.of());
                     }
-                }
-                catch (Exception e) {
-                    logger.error("Error processing chat completion", e);
-                    return new ChatResponse(List.of());
-                }
 
-            }));
+                }));
 
-        return chatResponse.flatMap(response -> {
+            // @formatter:off
+            Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
+
+                if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
+                        OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
+                    var toolCallConversation = handleToolCalls(prompt, response);
+                    // Recursively call the stream method with the tool call message
+                    // conversation that contains the call responses.
+                    return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+                }
+                else {
+                    return Flux.just(response);
+                }
+            })
+            .doOnError(observation::error)
+            .doFinally(s -> {
+                // TODO: Consider a custom ObservationContext and
+                // include additional metadata
+                // if (s == SignalType.CANCEL) {
+                //     observationContext.setAborted(true);
+                // }
+                observation.stop();
+            })
+            .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
+            // @formatter:on
+
+            return new MessageAggregator().aggregate(flux, mergedChatResponse -> {
+                observationContext.setResponse(mergedChatResponse);
+            });
 
-            if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
-                    OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
-                var toolCallConversation = handleToolCalls(prompt, response);
-                // Recursively call the stream method with the tool call message
-                // conversation that contains the call responses.
-                return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
-            }
-            else {
-                return Flux.just(response);
-            }
         });
     }
 
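A few notes on the patterns this diff leans on. The central one is Micrometer observation propagation through Reactor: the Observation is created per subscription inside Flux.deferContextual, linked to any parent observation found in the Reactor context, and then written back into the context so that nested pipelines (such as the recursive this.stream(...) call on tool calls) pick it up as their parent. Below is a minimal, self-contained sketch of that pattern using plain Micrometer and Reactor; the observation name "demo.stream" and the Flux.range source are placeholders for illustration, not taken from this PR.

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;

public class ObservedFluxSketch {

    public static void main(String[] args) {
        ObservationRegistry registry = ObservationRegistry.create();

        Flux<Integer> observed = Flux.deferContextual(contextView -> {
            // Created per subscription, so every subscriber gets its own observation.
            Observation observation = Observation.createNotStarted("demo.stream", registry);
            // Link to a parent observation if one is already in the Reactor context.
            observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null));
            observation.start();

            return Flux.range(1, 3)
                .doOnError(observation::error)
                // Runs on complete, error, and cancel alike, the same role the
                // doFinally plays in the diff above.
                .doFinally(signal -> observation.stop())
                // Publish the observation into the subscription context so that
                // upstream operators (including a nested deferContextual) see it.
                .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
        });

        observed.blockLast();
    }
}

Because doFinally cannot yet distinguish cancellation in the observation itself, the diff leaves a TODO about marking the context as aborted on SignalType.CANCEL.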
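A secondary detail worth spelling out, since it moved but did not change: in OpenAI streaming responses only the first chunk of a completion carries the assistant role, so the code caches the role per completion id and re-applies it to later chunks. Here is a toy illustration of that caching; the Chunk record is invented for the example and is not the PR's type.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class RoleMapSketch {

    // Hypothetical stand-in for a streamed chunk; only the first carries a role.
    record Chunk(String id, String role, String content) {}

    public static void main(String[] args) {
        ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

        Chunk[] chunks = {
            new Chunk("chatcmpl-1", "ASSISTANT", "Hel"), // role present on first chunk
            new Chunk("chatcmpl-1", null, "lo"),         // later chunks omit the role
            new Chunk("chatcmpl-1", null, "!")
        };

        for (Chunk chunk : chunks) {
            if (chunk.role() != null) {
                roleMap.putIfAbsent(chunk.id(), chunk.role());
            }
            // Every chunk's metadata gets the cached role, mirroring the diff.
            Map<String, Object> metadata = Map.of(
                    "id", chunk.id(),
                    "role", roleMap.getOrDefault(chunk.id(), ""));
            System.out.println(metadata + " -> " + chunk.content());
        }
    }
}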
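Last, the new MessageAggregator().aggregate(flux, callback) call returns the flux unchanged for subscribers but invokes the callback once the stream completes, with the chunks merged into a single ChatResponse; that is how observationContext.setResponse can record the full response even though the client only ever sees increments. A reduced sketch of that contract over plain strings follows; this aggregate helper is a stand-in written for the example, not the Spring AI implementation.

import java.util.function.Consumer;
import reactor.core.publisher.Flux;

public class AggregateSketch {

    // Pass the flux through unchanged, but hand the merged content to a callback
    // once the stream completes: the shape of MessageAggregator.aggregate(..).
    static Flux<String> aggregate(Flux<String> chunks, Consumer<String> onMerged) {
        StringBuilder merged = new StringBuilder();
        return chunks.doOnNext(merged::append)
            .doOnComplete(() -> onMerged.accept(merged.toString()));
    }

    public static void main(String[] args) {
        aggregate(Flux.just("Hel", "lo", "!"), full -> System.out.println("merged: " + full))
            .subscribe(chunk -> System.out.println("chunk: " + chunk));
        // Prints the three chunks first, then "merged: Hello!".
    }
}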