Skip to content

Commit 3978e8e

Browse files
committed
Enhance OpenAI Authentication and Configuration
- Add org-id and project-id properties with unified merging logic - Update autoconfig and docs for all OpenAI models - Introduce OpenAiChatOptions#httpHeaders option - Add integration test for httpHeaders and update docs Resolves #1141
1 parent 4c1347d commit 3978e8e

File tree

17 files changed

+853
-723
lines changed

17 files changed

+853
-723
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,16 @@
1515
*/
1616
package org.springframework.ai.openai;
1717

18-
import io.micrometer.observation.ObservationRegistry;
18+
import java.util.ArrayList;
19+
import java.util.Base64;
20+
import java.util.HashMap;
21+
import java.util.HashSet;
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.Set;
25+
import java.util.concurrent.ConcurrentHashMap;
26+
import java.util.stream.Collectors;
27+
1928
import org.slf4j.Logger;
2029
import org.slf4j.LoggerFactory;
2130
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -26,8 +35,16 @@
2635
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
2736
import org.springframework.ai.chat.metadata.EmptyUsage;
2837
import org.springframework.ai.chat.metadata.RateLimit;
29-
import org.springframework.ai.chat.model.*;
30-
import org.springframework.ai.chat.observation.*;
38+
import org.springframework.ai.chat.model.AbstractToolCallSupport;
39+
import org.springframework.ai.chat.model.ChatModel;
40+
import org.springframework.ai.chat.model.ChatResponse;
41+
import org.springframework.ai.chat.model.Generation;
42+
import org.springframework.ai.chat.model.StreamingChatModel;
43+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
44+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
45+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
46+
import org.springframework.ai.chat.observation.ChatModelRequestOptions;
47+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
3148
import org.springframework.ai.chat.prompt.ChatOptions;
3249
import org.springframework.ai.chat.prompt.Prompt;
3350
import org.springframework.ai.model.ModelOptionsUtils;
@@ -52,13 +69,13 @@
5269
import org.springframework.util.Assert;
5370
import org.springframework.util.CollectionUtils;
5471
import org.springframework.util.MimeType;
72+
import org.springframework.util.MultiValueMap;
5573
import org.springframework.util.StringUtils;
74+
75+
import io.micrometer.observation.ObservationRegistry;
5676
import reactor.core.publisher.Flux;
5777
import reactor.core.publisher.Mono;
5878

59-
import java.util.*;
60-
import java.util.concurrent.ConcurrentHashMap;
61-
6279
/**
6380
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
6481
* backed by {@link OpenAiApi}.
@@ -204,7 +221,7 @@ public ChatResponse call(Prompt prompt) {
204221
.observe(() -> {
205222

206223
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
207-
.execute(ctx -> this.openAiApi.chatCompletionEntity(request));
224+
.execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt)));
208225

209226
var chatCompletion = completionEntity.getBody();
210227

@@ -258,7 +275,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
258275
ChatCompletionRequest request = createRequest(prompt, true);
259276

260277
Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
261-
.execute(ctx -> this.openAiApi.chatCompletionStream(request));
278+
.execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
262279

263280
// For chunked responses, only the first chunk contains the choice role.
264281
// The rest of the chunks with same ID share the same role.
@@ -315,6 +332,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
315332
});
316333
}
317334

335+
private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) {
336+
337+
Map<String, String> headers = new HashMap<>(this.defaultOptions.getHttpHeaders());
338+
if (prompt.getOptions() != null && prompt.getOptions() instanceof OpenAiChatOptions chatOptions) {
339+
headers.putAll(chatOptions.getHttpHeaders());
340+
}
341+
return CollectionUtils.toMultiValueMap(
342+
headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> List.of(e.getValue()))));
343+
}
344+
318345
private Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
319346
List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
320347
: choice.message()

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.ai.openai;
1717

1818
import java.util.ArrayList;
19+
import java.util.HashMap;
1920
import java.util.HashSet;
2021
import java.util.List;
2122
import java.util.Map;
@@ -169,6 +170,13 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
169170
@NestedConfigurationProperty
170171
@JsonIgnore
171172
private Set<String> functions = new HashSet<>();
173+
174+
/**
175+
* Optional HTTP headers to be added to the chat completion request.
176+
*/
177+
@NestedConfigurationProperty
178+
@JsonIgnore
179+
private Map<String, String> httpHeaders = new HashMap<>();
172180
// @formatter:on
173181

174182
public static Builder builder() {
@@ -299,6 +307,12 @@ public Builder withFunction(String functionName) {
299307
return this;
300308
}
301309

310+
public Builder withHttpHeaders(Map<String, String> httpHeaders) {
311+
Assert.notNull(httpHeaders, "HTTP headers must not be null");
312+
this.options.httpHeaders = httpHeaders;
313+
return this;
314+
}
315+
302316
public OpenAiChatOptions build() {
303317
return this.options;
304318
}
@@ -478,6 +492,14 @@ public void setFunctions(Set<String> functionNames) {
478492
this.functions = functionNames;
479493
}
480494

495+
public Map<String, String> getHttpHeaders() {
496+
return this.httpHeaders;
497+
}
498+
499+
public void setHttpHeaders(Map<String, String> httpHeaders) {
500+
this.httpHeaders = httpHeaders;
501+
}
502+
481503
@Override
482504
public int hashCode() {
483505
final int prime = 31;
@@ -662,6 +684,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
662684
.withParallelToolCalls(fromOptions.getParallelToolCalls())
663685
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
664686
.withFunctions(fromOptions.getFunctions())
687+
.withHttpHeaders(fromOptions.getHttpHeaders())
665688
.build();
666689
}
667690

0 commit comments

Comments
 (0)