Skip to content

feat: support the chat_template_kwargs, with OpenAI-Compatible Server… #3744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ public class OpenAiChatOptions implements ToolCallingChatOptions {
*/
private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions;

/**
* This extra body for support thinking outside the context of the conversation.
*/
private @JsonProperty("chat_template_kwargs") Map<String,Object> chatTemplateKwargs;


/**
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
*/
Expand Down Expand Up @@ -268,6 +274,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
.metadata(fromOptions.getMetadata())
.reasoningEffort(fromOptions.getReasoningEffort())
.webSearchOptions(fromOptions.getWebSearchOptions())
.chatTemplateKwargs(fromOptions.chatTemplateKwargs)
.build();
}

Expand Down Expand Up @@ -564,6 +571,14 @@ public void setWebSearchOptions(WebSearchOptions webSearchOptions) {
this.webSearchOptions = webSearchOptions;
}

public Map<String, Object> getChatTemplateKwargs() {
return chatTemplateKwargs;
}

public void setChatTemplateKwargs(Map<String, Object> chatTemplateKwargs) {
this.chatTemplateKwargs = chatTemplateKwargs;
}

@Override
public OpenAiChatOptions copy() {
return OpenAiChatOptions.fromOptions(this);
Expand All @@ -576,7 +591,7 @@ public int hashCode() {
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders,
this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio,
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions);
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.chatTemplateKwargs);
}

@Override
Expand Down Expand Up @@ -609,7 +624,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
&& Objects.equals(this.metadata, other.metadata)
&& Objects.equals(this.reasoningEffort, other.reasoningEffort)
&& Objects.equals(this.webSearchOptions, other.webSearchOptions);
&& Objects.equals(this.webSearchOptions, other.webSearchOptions)
&& Objects.equals(this.chatTemplateKwargs, other.chatTemplateKwargs);
}

@Override
Expand Down Expand Up @@ -802,6 +818,11 @@ public Builder webSearchOptions(WebSearchOptions webSearchOptions) {
return this;
}

public Builder chatTemplateKwargs(Map<String, Object> chatTemplateKwargs) {
this.options.chatTemplateKwargs = chatTemplateKwargs;
return this;
}

public OpenAiChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,8 @@ public record ChatCompletionRequest(// @formatter:off
@JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
@JsonProperty("user") String user,
@JsonProperty("reasoning_effort") String reasoningEffort,
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions) {
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions,
@JsonProperty("chat_template_kwargs") Map<String,Object> chatTemplateKwargs) {

/**
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
Expand All @@ -1069,7 +1070,7 @@ public record ChatCompletionRequest(// @formatter:off
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, false, null, temperature, null,
null, null, null, null, null, null);
null, null, null, null, null, null, null);
}

/**
Expand All @@ -1083,7 +1084,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
this(messages, model, null, null, null, null, null, null,
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
null, null, null, stream, null, null, null,
null, null, null, null, null, null);
null, null, null, null, null, null, null);
}

/**
Expand All @@ -1098,7 +1099,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, temperature, null,
null, null, null, null, null, null);
null, null, null, null, null, null, null);
}

/**
Expand All @@ -1114,7 +1115,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
List<FunctionTool> tools, Object toolChoice) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, false, null, 0.8, null,
tools, toolChoice, null, null, null, null);
tools, toolChoice, null, null, null, null, null);
}

/**
Expand All @@ -1127,7 +1128,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
this(messages, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, null, null,
null, null, null, null, null, null);
null, null, null, null, null, null, null);
}

/**
Expand All @@ -1140,7 +1141,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions);
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.chatTemplateKwargs);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ void testCopy() {
.reasoningEffort("low")
.internalToolExecutionEnabled(true)
.httpHeaders(Map.of("header1", "value1"))
.chatTemplateKwargs(Map.of("enable_thinking", true))
.build();

OpenAiChatOptions copiedOptions = originalOptions.copy();
Expand Down Expand Up @@ -189,6 +190,7 @@ void testSetters() {
options.setReasoningEffort("high");
options.setInternalToolExecutionEnabled(false);
options.setHttpHeaders(Map.of("header2", "value2"));
options.setChatTemplateKwargs(Map.of("enable_thinking", true));

assertThat(options.getModel()).isEqualTo("test-model");
assertThat(options.getFrequencyPenalty()).isEqualTo(0.5);
Expand Down Expand Up @@ -223,6 +225,7 @@ void testSetters() {
options.setStopSequences(List.of("s1", "s2"));
assertThat(options.getStopSequences()).isEqualTo(List.of("s1", "s2"));
assertThat(options.getStop()).isEqualTo(List.of("s1", "s2"));
assertThat(options.getChatTemplateKwargs()).isEqualTo(Map.of("enable_thinking", true));
}

@Test
Expand Down Expand Up @@ -258,6 +261,7 @@ void testDefaultValues() {
assertThat(options.getToolContext()).isEqualTo(new HashMap<>());
assertThat(options.getStreamUsage()).isFalse();
assertThat(options.getStopSequences()).isNull();
assertThat(options.getChatTemplateKwargs()).isNull();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void validateReasoningTokens() {
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
null, null, null, "low", null);
null, null, null, "low", null, null);
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

assertThat(response).isNotNull();
Expand Down