Skip to content

Commit 1196ef2

Browse files
committed
Fix OpenAI API Tool Choice configuraiton options
- Related to the https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice - It seems that when a cuntion type is set explicitely: {"type": "function", "function": {"name": "my_function"}} the parallel calling is not working anymore! Resolves #551
1 parent 4e473aa commit 1196ef2

File tree

2 files changed

+7
-14
lines changed

2 files changed

+7
-14
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ public record ChatCompletionRequest (
319319
@JsonProperty("temperature") Float temperature,
320320
@JsonProperty("top_p") Float topP,
321321
@JsonProperty("tools") List<FunctionTool> tools,
322-
@JsonProperty("tool_choice") String toolChoice,
322+
@JsonProperty("tool_choice") Object toolChoice,
323323
@JsonProperty("user") String user) {
324324

325325
/**
@@ -360,7 +360,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
360360
* @param toolChoice Controls which (if any) function is called by the model.
361361
*/
362362
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
363-
List<FunctionTool> tools, String toolChoice) {
363+
List<FunctionTool> tools, Object toolChoice) {
364364
this(messages, model, null, null, null, null, null, null, null,
365365
null, null, null, false, 0.8f, null,
366366
tools, toolChoice, null);
@@ -396,8 +396,8 @@ public static class ToolChoiceBuilder {
396396
/**
397397
* Specifying a particular function forces the model to call that function.
398398
*/
399-
public static String FUNCTION(String functionName) {
400-
return ModelOptionsUtils.toJsonString(Map.of("type", "function", "function", Map.of("name", functionName)));
399+
public static Object FUNCTION(String functionName) {
400+
return Map.of("type", "function", "function", Map.of("name", functionName));
401401
}
402402
}
403403

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
3333
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
3434
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
35+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
3536
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
3637
import org.springframework.ai.openai.api.OpenAiApi.FunctionTool.Type;
3738
import org.springframework.http.ResponseEntity;
@@ -89,19 +90,11 @@ public void toolFunctionCall() {
8990
}
9091
""")));
9192

92-
// Or you can use the
93-
// ModelOptionsUtils.getJsonSchema(FakeWeatherService.Request.class))) to
94-
// auto-generate the JSON schema like:
95-
// var functionTool = new OpenAiApi.FunctionTool(Type.FUNCTION, new
96-
// OpenAiApi.FunctionTool.Function(
97-
// "Get the weather in location. Return temperature in 30°F or 30°C format.",
98-
// "getCurrentWeather",
99-
// ModelOptionsUtils.getJsonSchema(FakeWeatherService.Request.class)));
100-
10193
List<ChatCompletionMessage> messages = new ArrayList<>(List.of(message));
10294

10395
ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(messages, "gpt-4-turbo-preview",
104-
List.of(functionTool), null);
96+
List.of(functionTool), ToolChoiceBuilder.AUTO);
97+
// List.of(functionTool), ToolChoiceBuilder.FUNCTION("getCurrentWeather"));
10598

10699
ResponseEntity<ChatCompletion> chatCompletion = completionApi.chatCompletionEntity(chatCompletionRequest);
107100

0 commit comments

Comments
 (0)