From 602194318c50c75c08fa17d20a6545a30ee273dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=B6=E5=A8=83?= Date: Fri, 4 Jul 2025 16:45:23 +0800 Subject: [PATCH] feat: support the chat_template_kwargs, with OpenAI-Compatible Server(https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters_3), relate: (#3409) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 家娃 --- .../ai/openai/OpenAiChatOptions.java | 25 +++++++++++++++++-- .../ai/openai/api/OpenAiApi.java | 15 +++++------ .../ai/openai/OpenAiChatOptionsTests.java | 4 +++ .../ai/openai/api/OpenAiApiIT.java | 2 +- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index afbbd803ec6..bc25e46f8af 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -201,6 +201,12 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions; + /** + * Extra keyword arguments for the chat template, sent as the {@code chat_template_kwargs} request field (e.g. {@code enable_thinking}) to OpenAI-compatible servers such as vLLM. + */ + private @JsonProperty("chat_template_kwargs") Map chatTemplateKwargs; + + /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. 
*/ @@ -268,6 +274,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .metadata(fromOptions.getMetadata()) .reasoningEffort(fromOptions.getReasoningEffort()) .webSearchOptions(fromOptions.getWebSearchOptions()) + .chatTemplateKwargs(fromOptions.chatTemplateKwargs) .build(); } @@ -564,6 +571,14 @@ public void setWebSearchOptions(WebSearchOptions webSearchOptions) { this.webSearchOptions = webSearchOptions; } + public Map getChatTemplateKwargs() { + return chatTemplateKwargs; + } + + public void setChatTemplateKwargs(Map chatTemplateKwargs) { + this.chatTemplateKwargs = chatTemplateKwargs; + } + @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); @@ -576,7 +591,7 @@ public int hashCode() { this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders, this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio, - this.store, this.metadata, this.reasoningEffort, this.webSearchOptions); + this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.chatTemplateKwargs); } @Override @@ -609,7 +624,8 @@ public boolean equals(Object o) { && Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store) && Objects.equals(this.metadata, other.metadata) && Objects.equals(this.reasoningEffort, other.reasoningEffort) - && Objects.equals(this.webSearchOptions, other.webSearchOptions); + && Objects.equals(this.webSearchOptions, other.webSearchOptions) + && Objects.equals(this.chatTemplateKwargs, other.chatTemplateKwargs); } @Override @@ -802,6 +818,11 @@ public Builder webSearchOptions(WebSearchOptions webSearchOptions) { return this; } + public Builder chatTemplateKwargs(Map chatTemplateKwargs) { + this.options.chatTemplateKwargs = chatTemplateKwargs; + return this; + } + public OpenAiChatOptions build() { 
return this.options; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 383065fc209..1663002f706 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -1057,7 +1057,8 @@ public record ChatCompletionRequest(// @formatter:off @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls, @JsonProperty("user") String user, @JsonProperty("reasoning_effort") String reasoningEffort, - @JsonProperty("web_search_options") WebSearchOptions webSearchOptions) { + @JsonProperty("web_search_options") WebSearchOptions webSearchOptions, + @JsonProperty("chat_template_kwargs") Map chatTemplateKwargs) { /** * Shortcut constructor for a chat completion request with the given messages, model and temperature. @@ -1069,7 +1070,7 @@ public record ChatCompletionRequest(// @formatter:off public ChatCompletionRequest(List messages, String model, Double temperature) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, temperature, null, - null, null, null, null, null, null); + null, null, null, null, null, null, null); } /** @@ -1083,7 +1084,7 @@ public ChatCompletionRequest(List messages, String model, this(messages, model, null, null, null, null, null, null, null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null, null); + null, null, null, null, null, null, null); } /** @@ -1098,7 +1099,7 @@ public ChatCompletionRequest(List messages, String model, public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { this(messages, model, null, null, null, null, null, null, null, null, null, 
null, null, null, null, null, null, null, stream, null, temperature, null, - null, null, null, null, null, null); + null, null, null, null, null, null, null); } /** @@ -1114,7 +1115,7 @@ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, 0.8, null, - tools, toolChoice, null, null, null, null); + tools, toolChoice, null, null, null, null, null); } /** @@ -1127,7 +1128,7 @@ public ChatCompletionRequest(List messages, String model, public ChatCompletionRequest(List messages, Boolean stream) { this(messages, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null, null); + null, null, null, null, null, null, null); } /** @@ -1140,7 +1141,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) { return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, - this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions); + this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.chatTemplateKwargs); } /** diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java index 70b7f1fad66..691c19cf8a2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java +++ 
b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java @@ -141,6 +141,7 @@ void testCopy() { .reasoningEffort("low") .internalToolExecutionEnabled(true) .httpHeaders(Map.of("header1", "value1")) + .chatTemplateKwargs(Map.of("enable_thinking", true)) .build(); OpenAiChatOptions copiedOptions = originalOptions.copy(); @@ -189,6 +190,7 @@ void testSetters() { options.setReasoningEffort("high"); options.setInternalToolExecutionEnabled(false); options.setHttpHeaders(Map.of("header2", "value2")); + options.setChatTemplateKwargs(Map.of("enable_thinking", true)); assertThat(options.getModel()).isEqualTo("test-model"); assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); @@ -223,6 +225,7 @@ void testSetters() { options.setStopSequences(List.of("s1", "s2")); assertThat(options.getStopSequences()).isEqualTo(List.of("s1", "s2")); assertThat(options.getStop()).isEqualTo(List.of("s1", "s2")); + assertThat(options.getChatTemplateKwargs()).isEqualTo(Map.of("enable_thinking", true)); } @Test @@ -258,6 +261,7 @@ void testDefaultValues() { assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); assertThat(options.getStreamUsage()).isFalse(); assertThat(options.getStopSequences()).isNull(); + assertThat(options.getChatTemplateKwargs()).isNull(); } @Test diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index bf56a9fc2e8..7d95e48dd8c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -75,7 +75,7 @@ void validateReasoningTokens() { "If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER); ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null, 
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null, - null, null, null, "low", null); + null, null, null, "low", null, null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); assertThat(response).isNotNull();