|
15 | 15 | */
|
16 | 16 | package org.springframework.ai.openai;
|
17 | 17 |
|
18 |
| -import io.micrometer.observation.ObservationRegistry; |
| 18 | +import java.util.ArrayList; |
| 19 | +import java.util.Base64; |
| 20 | +import java.util.HashMap; |
| 21 | +import java.util.HashSet; |
| 22 | +import java.util.List; |
| 23 | +import java.util.Map; |
| 24 | +import java.util.Set; |
| 25 | +import java.util.concurrent.ConcurrentHashMap; |
| 26 | +import java.util.stream.Collectors; |
| 27 | + |
19 | 28 | import org.slf4j.Logger;
|
20 | 29 | import org.slf4j.LoggerFactory;
|
21 | 30 | import org.springframework.ai.chat.messages.AssistantMessage;
|
|
26 | 35 | import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
27 | 36 | import org.springframework.ai.chat.metadata.EmptyUsage;
|
28 | 37 | import org.springframework.ai.chat.metadata.RateLimit;
|
29 |
| -import org.springframework.ai.chat.model.*; |
30 |
| -import org.springframework.ai.chat.observation.*; |
| 38 | +import org.springframework.ai.chat.model.AbstractToolCallSupport; |
| 39 | +import org.springframework.ai.chat.model.ChatModel; |
| 40 | +import org.springframework.ai.chat.model.ChatResponse; |
| 41 | +import org.springframework.ai.chat.model.Generation; |
| 42 | +import org.springframework.ai.chat.model.StreamingChatModel; |
| 43 | +import org.springframework.ai.chat.observation.ChatModelObservationContext; |
| 44 | +import org.springframework.ai.chat.observation.ChatModelObservationConvention; |
| 45 | +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; |
| 46 | +import org.springframework.ai.chat.observation.ChatModelRequestOptions; |
| 47 | +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; |
31 | 48 | import org.springframework.ai.chat.prompt.ChatOptions;
|
32 | 49 | import org.springframework.ai.chat.prompt.Prompt;
|
33 | 50 | import org.springframework.ai.model.ModelOptionsUtils;
|
|
52 | 69 | import org.springframework.util.Assert;
|
53 | 70 | import org.springframework.util.CollectionUtils;
|
54 | 71 | import org.springframework.util.MimeType;
|
| 72 | +import org.springframework.util.MultiValueMap; |
55 | 73 | import org.springframework.util.StringUtils;
|
| 74 | + |
| 75 | +import io.micrometer.observation.ObservationRegistry; |
56 | 76 | import reactor.core.publisher.Flux;
|
57 | 77 | import reactor.core.publisher.Mono;
|
58 | 78 |
|
59 |
| -import java.util.*; |
60 |
| -import java.util.concurrent.ConcurrentHashMap; |
61 |
| - |
62 | 79 | /**
|
63 | 80 | * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
|
64 | 81 | * backed by {@link OpenAiApi}.
|
@@ -204,7 +221,7 @@ public ChatResponse call(Prompt prompt) {
|
204 | 221 | .observe(() -> {
|
205 | 222 |
|
206 | 223 | ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
|
207 |
| - .execute(ctx -> this.openAiApi.chatCompletionEntity(request)); |
| 224 | + .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); |
208 | 225 |
|
209 | 226 | var chatCompletion = completionEntity.getBody();
|
210 | 227 |
|
@@ -258,7 +275,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
|
258 | 275 | ChatCompletionRequest request = createRequest(prompt, true);
|
259 | 276 |
|
260 | 277 | Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
|
261 |
| - .execute(ctx -> this.openAiApi.chatCompletionStream(request)); |
| 278 | + .execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt))); |
262 | 279 |
|
263 | 280 | // For chunked responses, only the first chunk contains the choice role.
|
264 | 281 | // The rest of the chunks with same ID share the same role.
|
@@ -315,6 +332,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
|
315 | 332 | });
|
316 | 333 | }
|
317 | 334 |
|
| 335 | + private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) { |
| 336 | + |
| 337 | + Map<String, String> headers = new HashMap<>(this.defaultOptions.getHttpHeaders()); |
| 338 | + if (prompt.getOptions() != null && prompt.getOptions() instanceof OpenAiChatOptions chatOptions) { |
| 339 | + headers.putAll(chatOptions.getHttpHeaders()); |
| 340 | + } |
| 341 | + return CollectionUtils.toMultiValueMap( |
| 342 | + headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> List.of(e.getValue())))); |
| 343 | + } |
| 344 | + |
318 | 345 | private Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
|
319 | 346 | List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
|
320 | 347 | : choice.message()
|
|
0 commit comments