From f9d08c996577afa809ffb20148c1c754cb8bef80 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Thu, 20 Mar 2025 16:07:13 +0000 Subject: [PATCH] Remove FunctionCallback deprecations - Remove the super type FunctionCallingOptions from ToolCallingChatOptions - Move toolContext builder methods into ToolCallingChatOptions - Remove Model chat options' function specific usages - Replace them with tooling: FunctionCallback -> ToolCallback functions -> toolNames - Remove proxyToolCalls use - Remove deprecated methods - Update ChatClient methods - Replace FunctionCallback -> ToolCallback - Remove deprecated methods - Update DefaultChatClient - functionNames -> toolNames - functionCallbacks -> toolCallbacks - Update AdviseRequest - functionNames -> toolNames - functionCallbacks -> toolCallbacks - Remove FunctionCallingOptions and replace it with ToolCallingOptions - Remove FunctionCallingHelper - Update DefaultToolCallingChatOptions, ToolCallbackResolvers, ToolCallbackProvider to use Tool calling types - Update documentation Resolves #2528 Signed-off-by: Ilayaperumal Gopinathan --- .../tool/FunctionCallWithFunctionBeanIT.java | 4 +- .../FunctionCallWithPromptFunctionIT.java | 2 +- .../tool/FunctionCallWithFunctionBeanIT.java | 4 +- .../FunctionCallWithFunctionWrapperIT.java | 2 +- .../FunctionCallWithPromptFunctionIT.java | 2 +- .../tool/PaymentStatusBeanIT.java | 4 +- .../tool/PaymentStatusBeanOpenAiIT.java | 4 +- .../tool/PaymentStatusPromptIT.java | 2 +- .../tool/WeatherServicePromptIT.java | 9 +- .../tool/FunctionCallbackInPromptIT.java | 6 +- .../tool/OllamaFunctionCallbackIT.java | 4 +- .../tool/OllamaFunctionToolBeanIT.java | 2 +- .../tool/FunctionCallbackContextKotlinIT.kt | 2 +- .../tool/FunctionCallbackInPrompt2IT.java | 8 +- .../tool/FunctionCallbackInPromptIT.java | 4 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 24 +- .../tool/OpenAiFunctionCallback2IT.java | 4 +- .../tool/OpenAiFunctionCallbackIT.java | 4 +- .../ToolCallingAutoConfiguration.java | 6 +- .../ToolCallingAutoConfigurationTests.java | 11 +- .../ai/anthropic/AnthropicChatOptions.java | 76 +-- .../ai/anthropic/AnthropicChatModelIT.java | 2 +- .../azure/openai/AzureOpenAiChatOptions.java | 74 +-- .../AzureOpenAiChatModelFunctionCallIT.java | 10 +- .../converse/api/ConverseApiUtils.java | 2 +- .../minimax/ChatCompletionRequestTests.java | 18 +- .../ai/mistralai/MistralAiChatModel.java | 5 - .../ai/mistralai/MistralAiChatOptions.java | 75 +-- .../MistralAiChatCompletionRequestTest.java | 3 +- .../ai/mistralai/MistralAiChatModelIT.java | 6 +- .../ai/ollama/api/OllamaOptions.java | 82 +-- .../OllamaChatModelFunctionCallingIT.java | 4 +- .../ai/ollama/OllamaChatRequestTests.java | 3 +- .../ollama/api/OllamaModelOptionsTests.java | 27 +- .../ai/openai/OpenAiChatModel.java | 5 - .../ai/openai/OpenAiChatOptions.java | 79 +-- .../ai/openai/ChatCompletionRequestTests.java | 12 +- .../ai/openai/OpenAiChatOptionsTests.java | 17 +- .../OpenAiChatModelFunctionCallingIT.java | 17 +- .../ai/openai/chat/OpenAiChatModelIT.java | 13 +- .../chat/OpenAiPaymentTransactionIT.java | 8 +- .../OpenAiChatClientProxyFunctionCallsIT.java | 185 ------ .../proxy/DeepSeekWithOpenAiChatModelIT.java | 8 +- ...ockerModelRunnerWithOpenAiChatModelIT.java | 8 +- .../chat/proxy/GroqWithOpenAiChatModelIT.java | 8 +- .../proxy/MistralWithOpenAiChatModelIT.java | 4 +- .../proxy/NvidiaWithOpenAiChatModelIT.java | 8 +- .../proxy/OllamaWithOpenAiChatModelIT.java | 8 +- .../PerplexityWithOpenAiChatModelIT.java | 8 +- .../gemini/VertexAiGeminiChatModel.java | 99 --- .../gemini/VertexAiGeminiChatOptions.java | 56 +- .../gemini/CreateGeminiRequestTests.java | 7 +- .../gemini/TestVertexAiGeminiChatModel.java | 7 +- .../gemini/VertexAiGeminiRetryTests.java | 3 +- ...texAiGeminiChatModelFunctionCallingIT.java | 12 +- .../VertexAiGeminiPaymentTransactionIT.java | 6 +- ...texAiGeminiPaymentTransactionMethodIT.java | 8 +- ...rtexAiGeminiPaymentTransactionToolsIT.java | 6 +- .../ai/chat/client/ChatClient.java | 23 +- .../ai/chat/client/DefaultChatClient.java | 39 +- .../chat/client/DefaultChatClientBuilder.java | 32 +- .../client/advisor/api/AdvisedRequest.java | 64 +- ...efaultChatClientObservationConvention.java | 12 +- .../ai/chat/client/ChatClientTest.java | 72 ++- .../chat/client/DefaultChatClientTests.java | 94 ++- .../advisor/api/AdvisedRequestTests.java | 6 +- ...tChatClientObservationConventionTests.java | 23 +- .../src/main/antora/modules/ROOT/nav.adoc | 2 +- .../functions/anthropic-chat-functions.adoc | 12 +- .../azure-open-ai-chat-functions.adoc | 12 +- .../functions/mistralai-chat-functions.adoc | 12 +- .../chat/functions/ollama-chat-functions.adoc | 12 +- .../chat/functions/openai-chat-functions.adoc | 15 +- .../modules/ROOT/pages/api/functions.adoc | 90 +-- .../model/function/FunctionCallingHelper.java | 189 ------ .../function/FunctionCallingOptions.java | 8 +- .../tool/DefaultToolCallingChatOptions.java | 74 +-- .../model/tool/DefaultToolCallingManager.java | 45 +- .../model/tool/LegacyToolCallingManager.java | 251 -------- .../ai/model/tool/ToolCallingChatOptions.java | 72 +-- .../ai/tool/StaticToolCallbackProvider.java | 15 +- .../ai/tool/ToolCallbackProvider.java | 8 +- .../DelegatingToolCallbackResolver.java | 6 +- .../StaticToolCallbackResolver.java | 15 +- .../tool/resolution/ToolCallbackResolver.java | 2 +- .../ai/tool/util/ToolUtils.java | 9 +- .../chat/prompt/ChatOptionsBuilderTests.java | 15 +- .../DefaultFunctionCallbackBuilderTests.java | 333 ---------- ...ultFunctionCallingOptionsBuilderTests.java | 580 ------------------ .../MethodInvokingFunctionCallbackTests.java | 161 ----- .../DefaultToolCallingChatOptionsTests.java | 18 +- ...oolExecutionEligibilityPredicateTests.java | 16 - .../tool/LegacyToolCallingManagerTests.java | 214 ------- .../tool/ToolCallingChatOptionsTests.java | 40 +- 94 files changed, 520 insertions(+), 3158 deletions(-) delete mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java delete mode 100644 spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java delete mode 100644 spring-ai-model/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java delete mode 100644 spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilderTests.java delete mode 100644 spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilderTests.java delete mode 100644 spring-ai-model/src/test/java/org/springframework/ai/model/function/MethodInvokingFunctionCallbackTests.java delete mode 100644 spring-ai-model/src/test/java/org/springframework/ai/model/tool/LegacyToolCallingManagerTests.java diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java index 8ec56423a35..91efc4bd6f9 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java @@ -66,14 +66,14 @@ void functionCallTest() { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - AnthropicChatOptions.builder().function("weatherFunction").build())); + AnthropicChatOptions.builder().toolNames("weatherFunction").build())); logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), - AnthropicChatOptions.builder().function("weatherFunction3").build())); + AnthropicChatOptions.builder().toolNames("weatherFunction3").build())); logger.info("Response: {}", response); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java index c937963cc4f..94e0c9cb354 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-anthropic/src/test/java/org/springframework/ai/model/anthropic/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java @@ -58,7 +58,7 @@ void functionCallTest() { "What's the weather like in San Francisco, in Paris and in Tokyo? Return the temperature in Celsius."); var promptOptions = AnthropicChatOptions.builder() - .functionCallbacks( + .toolCallbacks( List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location. Return temperature in 36°F or 36°C format.") .inputType(MockWeatherService.Request.class) diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java index e4a17e2a7b3..e1b0ff19ca2 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java @@ -67,14 +67,14 @@ void functionCallTest() { "What's the weather like in San Francisco, Paris and in Tokyo? Use Multi-turn function calling."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - AzureOpenAiChatOptions.builder().function("weatherFunction").build())); + AzureOpenAiChatOptions.builder().toolNames("weatherFunction").build())); logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), - AzureOpenAiChatOptions.builder().function("weatherFunction3").build())); + AzureOpenAiChatOptions.builder().toolNames("weatherFunction3").build())); logger.info("Response: {}", response); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionWrapperIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionWrapperIT.java index c8c61ae1555..244ba7c555a 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionWrapperIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithFunctionWrapperIT.java @@ -65,7 +65,7 @@ void functionCallTest() { "What's the weather like in San Francisco, Paris and in Tokyo?"); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - AzureOpenAiChatOptions.builder().function("WeatherInfo").build())); + AzureOpenAiChatOptions.builder().toolNames("WeatherInfo").build())); logger.info("Response: {}", response); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java index 8f2d63b8e3c..e553f9c9ce5 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java @@ -61,7 +61,7 @@ void functionCallTest() { "What's the weather like in San Francisco, in Paris and in Tokyo? Use Multi-turn function calling."); var promptOptions = AzureOpenAiChatOptions.builder() - .functionCallbacks( + .toolCallbacks( List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanIT.java index bacc8182cb1..d57c065bd16 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanIT.java @@ -68,8 +68,8 @@ void functionCallTest() { ChatResponse response = chatModel .call(new Prompt(List.of(new UserMessage("What's the status of my transaction with id T1001?")), MistralAiChatOptions.builder() - .function("retrievePaymentStatus") - .function("retrievePaymentDate") + .toolNames("retrievePaymentStatus") + .toolNames("retrievePaymentDate") .build())); logger.info("Response: {}", response); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanOpenAiIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanOpenAiIT.java index ad66bc99d63..85ef41db018 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanOpenAiIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusBeanOpenAiIT.java @@ -75,8 +75,8 @@ void functionCallTest() { ChatResponse response = chatModel .call(new Prompt(List.of(new UserMessage("What's the status of my transaction with id T1001?")), OpenAiChatOptions.builder() - .function("retrievePaymentStatus") - .function("retrievePaymentDate") + .toolNames("retrievePaymentStatus") + .toolNames("retrievePaymentDate") .build())); logger.info("Response: {}", response); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusPromptIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusPromptIT.java index c524dcb8523..f7e8a4be090 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusPromptIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/PaymentStatusPromptIT.java @@ -64,7 +64,7 @@ void functionCallTest() { UserMessage userMessage = new UserMessage("What's the status of my transaction with id T1001?"); var promptOptions = MistralAiChatOptions.builder() - .functionCallbacks(List.of(FunctionToolCallback + .toolCallbacks(List.of(FunctionToolCallback .builder("retrievePaymentStatus", (Transaction transaction) -> new Status(DATA.get(transaction).status())) .description("Get payment status of a transaction") diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/WeatherServicePromptIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/WeatherServicePromptIT.java index 45d40d981ff..13d4e9db8bd 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/WeatherServicePromptIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-mistral-ai/src/test/java/org/springframework/ai/model/mistralai/autoconfigure/tool/WeatherServicePromptIT.java @@ -73,11 +73,10 @@ void promptFunctionCall() { var promptOptions = MistralAiChatOptions.builder() .toolChoice(ToolChoice.AUTO) - .functionCallbacks( - List.of(FunctionToolCallback.builder("CurrentWeatherService", new MyWeatherService()) - .description("Get the current weather in requested location") - .inputType(MyWeatherService.Request.class) - .build())) + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MyWeatherService()) + .description("Get the current weather in requested location") + .inputType(MyWeatherService.Request.class) + .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java index 76fcaf77f24..8098f0b1866 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackInPromptIT.java @@ -70,8 +70,7 @@ void functionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); var promptOptions = OllamaOptions.builder() - .functionCallbacks(List.of(FunctionToolCallback - .builder("CurrentWeatherService", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) @@ -96,8 +95,7 @@ void streamingFunctionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); var promptOptions = OllamaOptions.builder() - .functionCallbacks(List.of(FunctionToolCallback - .builder("CurrentWeatherService", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java index 56f1db05660..cbaf6362bce 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionCallbackIT.java @@ -75,7 +75,7 @@ void functionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); ChatResponse response = chatModel - .call(new Prompt(List.of(userMessage), OllamaOptions.builder().function("WeatherInfo").build())); + .call(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build())); logger.info("Response: " + response); @@ -93,7 +93,7 @@ void streamFunctionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); Flux response = chatModel - .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().function("WeatherInfo").build())); + .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("WeatherInfo").build())); String content = response.collectList() .block() diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java index faaf0ea5fcb..a565870131b 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/tool/OllamaFunctionToolBeanIT.java @@ -121,7 +121,7 @@ void streamFunctionCallTest() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); Flux response = chatModel - .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().function("weatherInfo").build())); + .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build())); String content = response.collectList() .block() diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt index 1b7d1007764..63ad113bcc8 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt +++ b/auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/kotlin/org/springframework/ai/model/ollama/autoconfigure/tool/FunctionCallbackContextKotlinIT.kt @@ -68,7 +68,7 @@ class FunctionCallbackResolverKotlinIT : BaseOllamaIT() { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.") val response = chatModel - .call(Prompt(listOf(userMessage), OllamaOptions.builder().function("weatherInfo").build())) + .call(Prompt(listOf(userMessage), OllamaOptions.builder().toolNames("weatherInfo").build())) logger.info("Response: $response") diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPrompt2IT.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPrompt2IT.java index 7b334589520..a66b94b3e77 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPrompt2IT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPrompt2IT.java @@ -60,7 +60,7 @@ void functionCallTest() { String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionToolCallback + .tools(FunctionToolCallback .builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) @@ -88,7 +88,7 @@ record LightInfo(String roomName, boolean isOn) { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("Turn the light on in the kitchen and in the living room!") - .functions(FunctionToolCallback + .tools(FunctionToolCallback .builder("turnLight", (LightInfo lightInfo) -> { logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); state.put(lightInfo.roomName(), lightInfo.isOn()); @@ -114,7 +114,7 @@ void functionCallTest2() { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in Amsterdam?") - .functions(FunctionToolCallback + .tools(FunctionToolCallback .builder("CurrentWeatherService", input -> "18 degrees Celsius") .description("Get the weather in location") .inputType(MockWeatherService.Request.class) @@ -138,7 +138,7 @@ void streamingFunctionCallTest() { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionToolCallback + .tools(FunctionToolCallback .builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPromptIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPromptIT.java index c3073a56dad..5550d3fcee3 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPromptIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackInPromptIT.java @@ -62,7 +62,7 @@ void functionCallTest() { "What's the weather like in San Francisco, Tokyo, and Paris?"); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks( + .toolCallbacks( List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) @@ -91,7 +91,7 @@ void streamingFunctionCallTest() { "What's the weather like in San Francisco, Tokyo, and Paris?"); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks( + .toolCallbacks( List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 7216afeaedd..a8f511d40e9 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -80,7 +80,7 @@ void functionCallingVoidInput() { UserMessage userMessage = new UserMessage("Turn the light on in the living room"); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().function("turnLivingRoomLightOn").build())); + OpenAiChatOptions.builder().toolNames("turnLivingRoomLightOn").build())); logger.info("Response: {}", response); assertThat(feedback).hasSize(1); @@ -98,7 +98,7 @@ void functionCallingSupplier() { UserMessage userMessage = new UserMessage("Turn the light on in the living room"); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().function("turnLivingRoomLightOnSupplier").build())); + OpenAiChatOptions.builder().toolNames("turnLivingRoomLightOnSupplier").build())); logger.info("Response: {}", response); assertThat(feedback).hasSize(1); @@ -116,7 +116,7 @@ void functionCallingVoidOutput() { UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); ChatResponse response = chatModel - .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().function("turnLight").build())); + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().toolNames("turnLight").build())); logger.info("Response: {}", response); assertThat(feedback).hasSize(2); @@ -135,7 +135,7 @@ void functionCallingConsumer() { UserMessage userMessage = new UserMessage("Turn the light on in the kitchen and in the living room"); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().function("turnLightConsumer").build())); + OpenAiChatOptions.builder().toolNames("turnLightConsumer").build())); logger.info("Response: {}", response); assertThat(feedback).hasSize(2); @@ -174,7 +174,7 @@ void functionCallWithDirectBiFunction() { ChatClient chatClient = ChatClient.builder(chatModel).build(); String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithContext") + .tools("weatherFunctionWithContext") .toolContext(Map.of("sessionId", "123")) .call() .content(); @@ -186,7 +186,7 @@ void functionCallWithDirectBiFunction() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder() - .function("weatherFunctionWithContext") + .toolNames("weatherFunctionWithContext") .toolContext(Map.of("sessionId", "123")) .build())); @@ -206,7 +206,7 @@ void functionCallWithBiFunctionClass() { ChatClient chatClient = ChatClient.builder(chatModel).build(); String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("weatherFunctionWithClassBiFunction") + .tools("weatherFunctionWithClassBiFunction") .toolContext(Map.of("sessionId", "123")) .call() .content(); @@ -218,7 +218,7 @@ void functionCallWithBiFunctionClass() { ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder() - .function("weatherFunctionWithClassBiFunction") + .toolNames("weatherFunctionWithClassBiFunction") .toolContext(Map.of("sessionId", "123")) .build())); @@ -240,7 +240,7 @@ void functionCallTest() { "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); ChatResponse response = chatModel.call( - new Prompt(List.of(userMessage), OpenAiChatOptions.builder().function("weatherFunction").build())); + new Prompt(List.of(userMessage), OpenAiChatOptions.builder().toolNames("weatherFunction").build())); logger.info("Response: {}", response); @@ -248,7 +248,7 @@ void functionCallTest() { // Test weatherFunctionTwo response = chatModel.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().function("weatherFunctionTwo").build())); + OpenAiChatOptions.builder().toolNames("weatherFunctionTwo").build())); logger.info("Response: {}", response); @@ -289,7 +289,7 @@ void streamFunctionCallTest() { "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'"); Flux response = chatModel.stream( - new Prompt(List.of(userMessage), OpenAiChatOptions.builder().function("weatherFunction").build())); + new Prompt(List.of(userMessage), OpenAiChatOptions.builder().toolNames("weatherFunction").build())); String content = response.collectList() .block() @@ -305,7 +305,7 @@ void streamFunctionCallTest() { // Test weatherFunctionTwo response = chatModel.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().function("weatherFunctionTwo").build())); + OpenAiChatOptions.builder().toolNames("weatherFunctionTwo").build())); content = response.collectList() .block() diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallback2IT.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallback2IT.java index bb22b7bfd27..0b71c6f618b 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallback2IT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallback2IT.java @@ -55,7 +55,7 @@ void functionCallTest() { // @formatter:off ChatClient chatClient = ChatClient.builder(chatModel) - .defaultFunctions("WeatherInfo") + .defaultTools("WeatherInfo") .defaultUser(u -> u.text("What's the weather like in {cities}?")) .build(); @@ -78,7 +78,7 @@ void streamFunctionCallTest() { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() - .functions("WeatherInfo") + .tools("WeatherInfo") .user("What's the weather like in San Francisco, Tokyo, and Paris?") .stream().content() .collectList().block().stream().collect(Collectors.joining()); diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallbackIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallbackIT.java index 67d6d807690..b490eb04edf 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallbackIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-openai/src/test/java/org/springframework/ai/model/openai/autoconfigure/tool/OpenAiFunctionCallbackIT.java @@ -63,7 +63,7 @@ void functionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); ChatResponse response = chatModel - .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().function("WeatherInfo").build())); + .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().toolNames("WeatherInfo").build())); logger.info("Response: {}", response); @@ -82,7 +82,7 @@ void streamFunctionCallTest() { "What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'WeatherInfo'"); Flux response = chatModel - .stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().function("WeatherInfo").build())); + .stream(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().toolNames("WeatherInfo").build())); String content = response.collectList() .block() diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java index b653b67d7c7..0090a6610ef 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java @@ -22,8 +22,8 @@ import io.micrometer.observation.ObservationRegistry; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; @@ -52,9 +52,9 @@ public class ToolCallingAutoConfiguration { @Bean @ConditionalOnMissingBean ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - List functionCallbacks, List tcbProviders) { + List toolCallbacks, List tcbProviders) { - List allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks); + List allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks); tcbProviders.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionAndToolCallbacks::addAll); var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks); diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java index bb03817dfb8..be73e1d2b2a 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java @@ -20,7 +20,6 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.DefaultToolCallingManager; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.tool.StaticToolCallbackProvider; @@ -135,18 +134,16 @@ public Function weatherFunction1() { } @Bean - public FunctionCallback functionCallbacks3() { - return FunctionCallback.builder() - .function("getCurrentWeather3", (Request request) -> "15.0°C") + public ToolCallback functionCallbacks3() { + return FunctionToolCallback.builder("getCurrentWeather3", (Request request) -> "15.0°C") .description("Gets the weather in location") .inputType(Request.class) .build(); } @Bean - public FunctionCallback functionCallbacks4() { - return FunctionCallback.builder() - .function("getCurrentWeather4", (Request request) -> "15.0°C") + public ToolCallback functionCallbacks4() { + return FunctionToolCallback.builder("getCurrentWeather4", (Request request) -> "15.0°C") .description("Gets the weather in location") .inputType(Request.class) .build(); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 2292041ee5d..dbfbee561c8 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -32,7 +32,6 @@ import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -44,6 +43,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Alexandros Pappas + * @author Ilayaperumal Gopinathan * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -64,7 +64,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { * completion requests. */ @JsonIgnore - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** * Collection of tool names to be resolved at runtime and used for tool calling in the @@ -186,13 +186,13 @@ public void setThinking(ChatCompletionRequest.ThinkingConfig thinking) { @Override @JsonIgnore - public List getToolCallbacks() { + public List getToolCallbacks() { return this.toolCallbacks; } @Override @JsonIgnore - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks; @@ -226,34 +226,6 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } - @Override - @Deprecated - @JsonIgnore - public List getFunctionCallbacks() { - return this.getToolCallbacks(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctionCallbacks(List functionCallbacks) { - this.setToolCallbacks(functionCallbacks); - } - - @Override - @Deprecated - @JsonIgnore - public Set getFunctions() { - return this.getToolNames(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctions(Set functionNames) { - this.setToolNames(functionNames); - } - @Override @JsonIgnore public Double getFrequencyPenalty() { @@ -266,19 +238,6 @@ public Double getPresencePenalty() { return null; } - @Override - @Deprecated - @JsonIgnore - public Boolean getProxyToolCalls() { - return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; - } - - @Deprecated - @JsonIgnore - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; - } - @Override @JsonIgnore public Map getToolContext() { @@ -387,12 +346,12 @@ public Builder thinking(AnthropicApi.ThinkingType type, Integer budgetTokens) { return this; } - public Builder toolCallbacks(List toolCallbacks) { + public Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + public Builder toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; @@ -415,29 +374,6 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } - @Deprecated - public Builder functionCallbacks(List functionCallbacks) { - return toolCallbacks(functionCallbacks); - } - - @Deprecated - public Builder functions(Set functionNames) { - return toolNames(functionNames); - } - - @Deprecated - public Builder function(String functionName) { - return toolNames(functionName); - } - - @Deprecated - public Builder proxyToolCalls(Boolean proxyToolCalls) { - if (proxyToolCalls != null) { - this.options.setInternalToolExecutionEnabled(!proxyToolCalls); - } - return this; - } - public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index a13c6cca869..e629c3411ad 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -424,7 +424,7 @@ void testToolUseContentBlock() { var promptOptions = AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName()) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index c301bdd1d4f..78651bd5a62 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -32,7 +32,6 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -185,7 +184,7 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions { * completion requests. */ @JsonIgnore - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** * Collection of tool names to be resolved at runtime and used for tool calling in the @@ -202,13 +201,13 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @Override @JsonIgnore - public List getToolCallbacks() { + public List getToolCallbacks() { return this.toolCallbacks; } @Override @JsonIgnore - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks; @@ -256,9 +255,9 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .user(fromOptions.getUser()) - .functionCallbacks(fromOptions.getFunctionCallbacks() != null - ? new ArrayList<>(fromOptions.getFunctionCallbacks()) : null) - .functions(fromOptions.getFunctions() != null ? new HashSet<>(fromOptions.getFunctions()) : null) + .toolCallbacks( + fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) + .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .responseFormat(fromOptions.getResponseFormat()) .seed(fromOptions.getSeed()) .logprobs(fromOptions.isLogprobs()) @@ -380,27 +379,6 @@ public void setTopP(Double topP) { this.topP = topP; } - @Override - @Deprecated - @JsonIgnore - public List getFunctionCallbacks() { - return this.getToolCallbacks(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctionCallbacks(List functionCallbacks) { - this.setToolCallbacks(functionCallbacks); - } - - @Override - @Deprecated - @JsonIgnore - public Set getFunctions() { - return this.getToolNames(); - } - public void setFunctions(Set functions) { this.setToolNames(functions); } @@ -451,19 +429,6 @@ public void setEnhancements(AzureChatEnhancementConfiguration enhancements) { this.enhancements = enhancements; } - @Override - @Deprecated - @JsonIgnore - public Boolean getProxyToolCalls() { - return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; - } - - @Deprecated - @JsonIgnore - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; - } - @Override public Map getToolContext() { return this.toolContext; @@ -583,34 +548,11 @@ public Builder user(String user) { return this; } - @Deprecated - public Builder functionCallbacks(List functionCallbacks) { - return toolCallbacks(functionCallbacks); - } - - @Deprecated - public Builder functions(Set functionNames) { - return toolNames(functionNames); - } - - @Deprecated - public Builder function(String functionName) { - return toolNames(functionName); - } - public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) { this.options.responseFormat = responseFormat; return this; } - @Deprecated - public Builder proxyToolCalls(Boolean proxyToolCalls) { - if (proxyToolCalls != null) { - this.options.setInternalToolExecutionEnabled(!proxyToolCalls); - } - return this; - } - public Builder seed(Long seed) { this.options.seed = seed; return this; @@ -646,12 +588,12 @@ public Builder streamOptions(ChatCompletionStreamOptions streamOptions) { return this; } - public Builder toolCallbacks(List toolCallbacks) { + public Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + public Builder toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index f5c7cfc6330..699d0ac05cb 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -70,7 +70,7 @@ void functionCallTest() { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -98,7 +98,7 @@ void functionCallSequentialTest() { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -119,7 +119,7 @@ void streamFunctionCallTest() { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -156,7 +156,7 @@ void streamFunctionCallUsageTest() { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -182,7 +182,7 @@ void functionCallSequentialAndStreamTest() { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index fc233beb5ea..a19de831a7e 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -329,7 +329,7 @@ public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions de attributes.remove("proxyToolCalls"); attributes.remove("functions"); attributes.remove("toolContext"); - attributes.remove("functionCallbacks"); + attributes.remove("toolCallbacks"); attributes.remove("toolCallbacks"); attributes.remove("toolNames"); diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java index c2d53bf2066..b229912039a 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/ChatCompletionRequestTests.java @@ -24,6 +24,7 @@ import org.springframework.ai.minimax.api.MiniMaxApi; import org.springframework.ai.minimax.api.MockWeatherService; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import static org.assertj.core.api.Assertions.assertThat; @@ -64,16 +65,13 @@ public void promptOptionsTools() { var client = new MiniMaxChatModel(new MiniMaxApi("TEST"), MiniMaxChatOptions.builder().model("DEFAULT_MODEL").build()); - var request = client.createRequest(new Prompt("Test message content", - MiniMaxChatOptions.builder() - .model("PROMPT_MODEL") - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) - .build()), - false); + var request = client.createRequest(new Prompt("Test message content", MiniMaxChatOptions.builder() + .model("PROMPT_MODEL") + .functionCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build()), false); assertThat(client.getFunctionCallbackRegister()).hasSize(1); assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 65eee34aa94..7b8a3ee9136 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -60,7 +60,6 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; @@ -381,10 +380,6 @@ Prompt buildRequestPrompt(Prompt prompt) { runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, MistralAiChatOptions.class); } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, - MistralAiChatOptions.class); - } else { runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, MistralAiChatOptions.class); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index e912c2ebcda..700e12dfe70 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -33,7 +33,6 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -120,7 +119,7 @@ public class MistralAiChatOptions implements ToolCallingChatOptions { * completion requests. */ @JsonIgnore - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** * Collection of tool names to be resolved at runtime and used for tool calling in the @@ -257,13 +256,13 @@ public void setTopP(Double topP) { @Override @JsonIgnore - public List getToolCallbacks() { + public List getToolCallbacks() { return this.toolCallbacks; } @Override @JsonIgnore - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks; @@ -297,34 +296,6 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } - @Override - @Deprecated - @JsonIgnore - public List getFunctionCallbacks() { - return this.getToolCallbacks(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctionCallbacks(List functionCallbacks) { - this.setToolCallbacks(functionCallbacks); - } - - @Override - @Deprecated - @JsonIgnore - public Set getFunctions() { - return this.getToolNames(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctions(Set functionNames) { - this.setToolNames(functionNames); - } - @Override @JsonIgnore public Double getFrequencyPenalty() { @@ -343,19 +314,6 @@ public Integer getTopK() { return null; } - @Override - @Deprecated - @JsonIgnore - public Boolean getProxyToolCalls() { - return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; - } - - @Deprecated - @JsonIgnore - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; - } - @Override @JsonIgnore public Map getToolContext() { @@ -463,12 +421,12 @@ public Builder toolChoice(ToolChoice toolChoice) { return this; } - public Builder toolCallbacks(List toolCallbacks) { + public Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + public Builder toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; @@ -491,29 +449,6 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } - @Deprecated - public Builder functionCallbacks(List functionCallbacks) { - return toolCallbacks(functionCallbacks); - } - - @Deprecated - public Builder functions(Set functionNames) { - return toolNames(functionNames); - } - - @Deprecated - public Builder function(String functionName) { - return toolNames(functionName); - } - - @Deprecated - public Builder proxyToolCalls(Boolean proxyToolCalls) { - if (proxyToolCalls != null) { - this.options.setInternalToolExecutionEnabled(!proxyToolCalls); - } - return this; - } - public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java index 379132ce361..d937609e76c 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java @@ -23,7 +23,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mistralai.api.MistralAiApi; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; @@ -96,7 +95,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() - .map(FunctionCallback::getName)).containsExactlyInAnyOrder("tool3", "tool4"); + .map(toolCallback -> toolCallback.getToolDefinition().name())).containsExactlyInAnyOrder("tool3", "tool4"); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") .containsEntry("key2", "valueB"); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java index 0a0c6f3c6ac..658dce1ec3d 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java @@ -204,7 +204,7 @@ void functionCallTest() { var promptOptions = MistralAiChatOptions.builder() .model(MistralAiApi.ChatModel.SMALL.getValue()) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -229,7 +229,7 @@ void streamFunctionCallTest() { var promptOptions = MistralAiChatOptions.builder() .model(MistralAiApi.ChatModel.SMALL.getValue()) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -315,7 +315,7 @@ void streamFunctionCallUsageTest() { var promptOptions = MistralAiChatOptions.builder() .model(MistralAiApi.ChatModel.SMALL.getValue()) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index cd2b32f059e..a71be1ce2b2 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -33,8 +33,8 @@ import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -323,17 +323,17 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { /** * Tool Function Callbacks to register with the ChatModel. - * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution. - * For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions + * For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution. + * For Default Options the toolCallbacks are registered but disabled by default. Use the enableFunctions to set the functions * from the registry to be used by the ChatModel chat completion requests. */ @JsonIgnore - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** * List of functions, identified by their names, to configure for function calling in * the chat completion requests. - * Functions with those names must exist in the functionCallbacks registry. + * Functions with those names must exist in the toolCallbacks registry. * The {@link #toolCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. * If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. @@ -706,13 +706,13 @@ public void setTruncate(Boolean truncate) { @Override @JsonIgnore - public List getToolCallbacks() { + public List getToolCallbacks() { return this.toolCallbacks; } @Override @JsonIgnore - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks; @@ -746,53 +746,12 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } - @Override - @Deprecated - @JsonIgnore - public List getFunctionCallbacks() { - return this.getToolCallbacks(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctionCallbacks(List functionCallbacks) { - this.setToolCallbacks(functionCallbacks); - } - - @Override - @Deprecated - @JsonIgnore - public Set getFunctions() { - return this.getToolNames(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctions(Set functions) { - this.setToolNames(functions); - } - @Override @JsonIgnore public Integer getDimensions() { return null; } - @Override - @Deprecated - @JsonIgnore - public Boolean getProxyToolCalls() { - return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; - } - - @Deprecated - @JsonIgnore - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; - } - @Override @JsonIgnore public Map getToolContext() { @@ -1043,12 +1002,12 @@ public Builder stop(List stop) { return this; } - public Builder toolCallbacks(List toolCallbacks) { + public Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + public Builder toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; @@ -1070,29 +1029,6 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } - @Deprecated - public Builder functionCallbacks(List functionCallbacks) { - return toolCallbacks(functionCallbacks); - } - - @Deprecated - public Builder functions(Set functions) { - return toolNames(functions); - } - - @Deprecated - public Builder function(String functionName) { - return toolNames(functionName); - } - - @Deprecated - public Builder proxyToolCalls(Boolean proxyToolCalls) { - if (proxyToolCalls != null) { - this.options.setInternalToolExecutionEnabled(!proxyToolCalls); - } - return this; - } - public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index 25f800ca9fa..ef149203f65 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -62,7 +62,7 @@ void functionCallTest() { var promptOptions = OllamaOptions.builder() .model(MODEL) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) @@ -85,7 +85,7 @@ void streamFunctionCallTest() { var promptOptions = OllamaOptions.builder() .model(MODEL) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index c6d8e61c04b..2f8d10c9e69 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -22,7 +22,6 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; @@ -69,7 +68,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() - .map(FunctionCallback::getName)).containsExactlyInAnyOrder("tool3", "tool4"); + .map(toolCallback -> toolCallback.getToolDefinition().name())).containsExactlyInAnyOrder("tool3", "tool4"); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") .containsEntry("key2", "valueB"); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java index 1eb6ff4582e..5667047215e 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaModelOptionsTests.java @@ -101,7 +101,6 @@ public void testBooleanOptions() { .useMMap(true) .useMLock(false) .penalizeNewline(true) - .proxyToolCalls(true) .build(); var optionsMap = options.toMap(); @@ -128,9 +127,9 @@ public void testModelAndFormat() { @Test public void testFunctionAndToolOptions() { var options = OllamaOptions.builder() - .function("function1") - .function("function2") - .function("function3") + .toolNames("function1") + .toolNames("function2") + .toolNames("function3") .toolContext(Map.of("key1", "value1", "key2", "value2")) .build(); @@ -140,7 +139,7 @@ public void testFunctionAndToolOptions() { assertThat(optionsMap).doesNotContainKey("tool_context"); // But they are accessible through getters - assertThat(options.getFunctions()).containsExactlyInAnyOrder("function1", "function2", "function3"); + assertThat(options.getToolNames()).containsExactlyInAnyOrder("function1", "function2", "function3"); assertThat(options.getToolContext()) .containsExactlyInAnyOrderEntriesOf(Map.of("key1", "value1", "key2", "value2")); } @@ -151,9 +150,9 @@ public void testFunctionOptionsWithMutableSet() { functionSet.add("function1"); functionSet.add("function2"); - var options = OllamaOptions.builder().functions(functionSet).function("function3").build(); + var options = OllamaOptions.builder().toolNames(functionSet).toolNames("function3").build(); - assertThat(options.getFunctions()).containsExactlyInAnyOrder("function1", "function2", "function3"); + assertThat(options.getToolNames()).containsExactlyInAnyOrder("function1", "function2", "function3"); } @Test @@ -162,7 +161,7 @@ public void testFromOptions() { .model("llama2") .temperature(0.7) .topK(40) - .functions(Set.of("function1")) + .toolNames(Set.of("function1")) .build(); var copiedOptions = OllamaOptions.fromOptions(originalOptions); @@ -171,30 +170,30 @@ public void testFromOptions() { assertThat(copiedOptions.getModel()).isEqualTo("llama2"); assertThat(copiedOptions.getTemperature()).isEqualTo(0.7); assertThat(copiedOptions.getTopK()).isEqualTo(40); - assertThat(copiedOptions.getFunctions()).containsExactly("function1"); + assertThat(copiedOptions.getToolNames()).containsExactly("function1"); } @Test public void testFunctionOptionsNotInMap() { - var options = OllamaOptions.builder().model("llama2").functions(Set.of("function1")).build(); + var options = OllamaOptions.builder().model("llama2").toolNames(Set.of("function1")).build(); var optionsMap = options.toMap(); // Verify function-related fields are not included in the map due to @JsonIgnore assertThat(optionsMap).containsEntry("model", "llama2"); assertThat(optionsMap).doesNotContainKey("functions"); - assertThat(optionsMap).doesNotContainKey("functionCallbacks"); + assertThat(optionsMap).doesNotContainKey("toolCallbacks"); assertThat(optionsMap).doesNotContainKey("proxyToolCalls"); assertThat(optionsMap).doesNotContainKey("toolContext"); // But verify they are still accessible through getters - assertThat(options.getFunctions()).containsExactly("function1"); + assertThat(options.getToolNames()).containsExactly("function1"); } @SuppressWarnings("deprecation") @Test public void testDeprecatedMethods() { - var options = OllamaOptions.builder().model("llama2").temperature(0.7).topK(40).function("function1").build(); + var options = OllamaOptions.builder().model("llama2").temperature(0.7).topK(40).toolNames("function1").build(); var optionsMap = options.toMap(); assertThat(optionsMap).containsEntry("model", "llama2"); @@ -202,7 +201,7 @@ public void testDeprecatedMethods() { assertThat(optionsMap).containsEntry("top_k", 40); // Function is not in map but accessible via getter - assertThat(options.getFunctions()).containsExactly("function1"); + assertThat(options.getToolNames()).containsExactly("function1"); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index d6ea170e979..87636a18af2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -57,7 +57,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; @@ -493,10 +492,6 @@ Prompt buildRequestPrompt(Prompt prompt) { runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, OpenAiChatOptions.class); } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, - OpenAiChatOptions.class); - } else { runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, OpenAiChatOptions.class); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 3025d5f9056..c5687add88d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -31,7 +31,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; @@ -198,7 +197,7 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. */ @JsonIgnore - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** * Collection of tool names to be resolved at runtime and used for tool calling in the chat completion requests. @@ -440,23 +439,6 @@ public void setToolChoice(Object toolChoice) { this.toolChoice = toolChoice; } - @Override - @Deprecated - @JsonIgnore - public Boolean getProxyToolCalls() { - return this.getToolExecutionEnabled() != null ? !this.internalToolExecutionEnabled : null; - } - - private Boolean getToolExecutionEnabled() { - return this.internalToolExecutionEnabled; - } - - @Deprecated - @JsonIgnore - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; - } - public String getUser() { return this.user; } @@ -475,13 +457,13 @@ public void setParallelToolCalls(Boolean parallelToolCalls) { @Override @JsonIgnore - public List getToolCallbacks() { + public List getToolCallbacks() { return this.toolCallbacks; } @Override @JsonIgnore - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks; @@ -515,34 +497,6 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } - @Override - @Deprecated - @JsonIgnore - public List getFunctionCallbacks() { - return this.getToolCallbacks(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctionCallbacks(List functionCallbacks) { - this.setToolCallbacks(functionCallbacks); - } - - @Override - @Deprecated - @JsonIgnore - public Set getFunctions() { - return this.getToolNames(); - } - - @Override - @Deprecated - @JsonIgnore - public void setFunctions(Set functionNames) { - this.setToolNames(functionNames); - } - public Map getHttpHeaders() { return this.httpHeaders; } @@ -767,12 +721,12 @@ public Builder parallelToolCalls(Boolean parallelToolCalls) { return this; } - public Builder toolCallbacks(List toolCallbacks) { + public Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + public Builder toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; @@ -795,29 +749,6 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } - @Deprecated - public Builder functionCallbacks(List functionCallbacks) { - return toolCallbacks(functionCallbacks); - } - - @Deprecated - public Builder functions(Set functionNames) { - return toolNames(functionNames); - } - - @Deprecated - public Builder function(String functionName) { - return toolNames(functionName); - } - - @Deprecated - public Builder proxyToolCalls(Boolean proxyToolCalls) { - if (proxyToolCalls != null) { - this.options.setInternalToolExecutionEnabled(!proxyToolCalls); - } - return this; - } - public Builder httpHeaders(Map httpHeaders) { this.options.httpHeaders = httpHeaders; return this; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 7d85cd3fe4f..1de3a339f4a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -23,12 +23,12 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.SimpleApiKey; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.function.FunctionToolCallback; import static org.assertj.core.api.Assertions.assertThat; @@ -66,7 +66,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() - .map(FunctionCallback::getName)).containsExactlyInAnyOrder("tool3", "tool4"); + .map(toolCallback -> toolCallback.getToolDefinition().name())).containsExactlyInAnyOrder("tool3", "tool4"); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool3"); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") .containsEntry("key2", "valueB"); @@ -111,8 +111,7 @@ void promptOptionsTools() { var prompt = client.buildRequestPrompt(new Prompt("Test message content", OpenAiChatOptions.builder() .model("PROMPT_MODEL") - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -136,8 +135,7 @@ void defaultOptionsTools() { .openAiApi(OpenAiApi.builder().apiKey("TEST").build()) .defaultOptions(OpenAiChatOptions.builder() .model("DEFAULT_MODEL") - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -156,7 +154,7 @@ void defaultOptionsTools() { // Reference the default options tool by name at runtime prompt = client.buildRequestPrompt( - new Prompt("Test message content", OpenAiChatOptions.builder().function(TOOL_FUNCTION_NAME).build())); + new Prompt("Test message content", OpenAiChatOptions.builder().toolNames(TOOL_FUNCTION_NAME).build())); request = client.createRequest(prompt, false); assertThat(request.tools()).hasSize(1); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java index 3ae8e204624..d09808f1a31 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java @@ -79,7 +79,7 @@ void testBuilderWithAllFields() { .store(false) .metadata(metadata) .reasoningEffort("medium") - .proxyToolCalls(false) + .internalToolExecutionEnabled(false) .httpHeaders(Map.of("header1", "value1")) .toolContext(toolContext) .build(); @@ -88,8 +88,8 @@ void testBuilderWithAllFields() { .extracting("model", "frequencyPenalty", "logitBias", "logprobs", "topLogprobs", "maxTokens", "maxCompletionTokens", "n", "outputModalities", "outputAudio", "presencePenalty", "responseFormat", "streamOptions", "seed", "stop", "temperature", "topP", "tools", "toolChoice", "user", - "parallelToolCalls", "store", "metadata", "reasoningEffort", "proxyToolCalls", "httpHeaders", - "toolContext") + "parallelToolCalls", "store", "metadata", "reasoningEffort", "internalToolExecutionEnabled", + "httpHeaders", "toolContext") .containsExactly("test-model", 0.5, logitBias, true, 5, 100, 50, 2, outputModalities, outputAudio, 0.8, responseFormat, streamOptions, 12345, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", true, false, metadata, "medium", false, Map.of("header1", "value1"), toolContext); @@ -138,7 +138,7 @@ void testCopy() { .store(true) .metadata(metadata) .reasoningEffort("low") - .proxyToolCalls(true) + .internalToolExecutionEnabled(true) .httpHeaders(Map.of("header1", "value1")) .build(); @@ -186,7 +186,7 @@ void testSetters() { options.setStore(false); options.setMetadata(metadata); options.setReasoningEffort("high"); - options.setProxyToolCalls(false); + options.setInternalToolExecutionEnabled(false); options.setHttpHeaders(Map.of("header2", "value2")); assertThat(options.getModel()).isEqualTo("test-model"); @@ -213,7 +213,7 @@ void testSetters() { assertThat(options.getStore()).isFalse(); assertThat(options.getMetadata()).isEqualTo(metadata); assertThat(options.getReasoningEffort()).isEqualTo("high"); - assertThat(options.getProxyToolCalls()).isFalse(); + assertThat(options.getInternalToolExecutionEnabled()).isFalse(); assertThat(options.getHttpHeaders()).isEqualTo(Map.of("header2", "value2")); assertThat(options.getStreamUsage()).isTrue(); options.setStreamUsage(false); @@ -251,9 +251,8 @@ void testDefaultValues() { assertThat(options.getStore()).isNull(); assertThat(options.getMetadata()).isNull(); assertThat(options.getReasoningEffort()).isNull(); - assertThat(options.getFunctionCallbacks()).isNotNull().isEmpty(); - assertThat(options.getFunctions()).isNotNull().isEmpty(); - assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getToolCallbacks()).isNotNull().isEmpty(); + assertThat(options.getInternalToolExecutionEnabled()).isNull(); assertThat(options.getHttpHeaders()).isNotNull().isEmpty(); assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); assertThat(options.getStreamUsage()).isFalse(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java index e62358b4c60..2bcbb7f384d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java @@ -38,13 +38,13 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.api.tool.MockWeatherService.Request; import org.springframework.ai.openai.api.tool.MockWeatherService.Response; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -69,8 +69,7 @@ void functionCallSupplier() { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("Turn the light on in the living room") - .functions(FunctionCallback.builder() - .function("turnsLightOnInTheLivingRoom", () -> state.put("Light", "ON")) + .tools(FunctionToolCallback.builder("turnsLightOnInTheLivingRoom", () -> state.put("Light", "ON")) .build()) .call() .content(); @@ -84,8 +83,7 @@ void functionCallSupplier() { void functionCallTest() { functionCallTest(OpenAiChatOptions.builder() .model(OpenAiApi.ChatModel.GPT_4_O.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -120,8 +118,7 @@ else if (request.location().contains("San Francisco")) { functionCallTest(OpenAiChatOptions.builder() .model(OpenAiApi.ChatModel.GPT_4_O.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", biFunction) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", biFunction) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -146,8 +143,7 @@ void functionCallTest(OpenAiChatOptions promptOptions) { void streamFunctionCallTest() { streamFunctionCallTest(OpenAiChatOptions.builder() - .functionCallbacks(List.of((FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of((FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) // .responseConverter(response -> "" + response.temp() + response.unit()) @@ -182,8 +178,7 @@ else if (request.location().contains("San Francisco")) { }; OpenAiChatOptions promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of((FunctionCallback.builder() - .function("getCurrentWeather", biFunction) + .toolCallbacks(List.of((FunctionToolCallback.builder("getCurrentWeather", biFunction) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()))) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index b8d970d8eb0..5dbd922602c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -53,7 +53,6 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.api.OpenAiApi; @@ -342,8 +341,7 @@ void functionCallTestDeprecated() { var promptOptions = OpenAiChatOptions.builder() .model(OpenAiApi.ChatModel.GPT_4_O.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -389,8 +387,7 @@ void streamFunctionCallTest() { var promptOptions = OpenAiChatOptions.builder() // .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -422,8 +419,7 @@ void functionCallUsageTest() { var promptOptions = OpenAiChatOptions.builder() // .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -451,8 +447,7 @@ void streamFunctionCallUsageTest() { var promptOptions = OpenAiChatOptions.builder() // .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index 0843df10774..91331799f9d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -36,13 +36,13 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.converter.BeanOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; @@ -171,7 +171,7 @@ private AdvisedRequest before(AdvisedRequest request) { logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); + logger.info("Function names: " + request.toolNames()); logger.info("Options: " + request.chatOptions().toString()); @@ -245,9 +245,9 @@ public OpenAiChatModel openAiClient(OpenAiApi openAiApi, ToolCallingManager tool @Bean @ConditionalOnMissingBean ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - List functionCallbacks, List tcbProviders) { + List toolCallback, List tcbProviders) { - List allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks); + List allFunctionAndToolCallbacks = new ArrayList<>(toolCallback); tcbProviders.stream() .map(pr -> List.of(pr.getToolCallbacks())) .forEach(allFunctionAndToolCallbacks::addAll); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java deleted file mode 100644 index 5a072c9a28a..00000000000 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.openai.chat.client; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingHelper; -import org.springframework.ai.openai.OpenAiChatModel; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.openai.OpenAiTestConfiguration; -import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.testutils.AbstractIT; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.core.io.Resource; -import org.springframework.test.context.ActiveProfiles; -import org.springframework.util.CollectionUtils; - -import static org.assertj.core.api.Assertions.assertThat; - -@SpringBootTest(classes = OpenAiTestConfiguration.class) -@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") -@ActiveProfiles("logging-test") -class OpenAiChatClientProxyFunctionCallsIT extends AbstractIT { - - private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClientMultipleFunctionCallsIT.class); - - @Value("classpath:/prompts/system-message.st") - private Resource systemTextResource; - - FunctionCallback functionDefinition = new FunctionCallingHelper.FunctionDefinition("getWeatherInLocation", - "Get the weather in location", """ - { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["C", "F"] - } - }, - "required": ["location", "unit"] - } - """); - - @Autowired - private OpenAiChatModel chatModel; - - // Helper class that reuses some of the {@link AbstractToolCallSupport} functionality - // to help to implement the function call handling logic on the client side. - private FunctionCallingHelper functionCallingHelper = new FunctionCallingHelper(); - - // Function which will be called by the AI model. - private String getWeatherInLocation(String location, String unit) { - - double temperature = 0; - - if (location.contains("Paris")) { - temperature = 15; - } - else if (location.contains("Tokyo")) { - temperature = 10; - } - else if (location.contains("San Francisco")) { - temperature = 30; - } - - return String.format("The weather in %s is %s%s", location, temperature, unit); - } - - @Test - void toolProxyFunctionCall() throws JsonMappingException, JsonProcessingException { - - List messages = List - .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - - boolean isToolCall = false; - - ChatResponse chatResponse = null; - - var chatClient = ChatClient.builder(this.chatModel).build(); - - do { - - chatResponse = chatClient.prompt() - .messages(messages) - .tools(this.functionDefinition) - .options(OpenAiChatOptions.builder().proxyToolCalls(true).build()) - .call() - .chatResponse(); - - // Note that the tool call check could be platform specific because the finish - // reasons. - isToolCall = this.functionCallingHelper.isToolCall(chatResponse, - Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name())); - - if (isToolCall) { - - Optional toolCallGeneration = chatResponse.getResults() - .stream() - .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) - .findFirst(); - - assertThat(toolCallGeneration).isNotEmpty(); - - AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); - - List toolResponses = new ArrayList<>(); - - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - - var functionName = toolCall.name(); - - assertThat(functionName).isEqualTo("getWeatherInLocation"); - - String functionArguments = toolCall.arguments(); - - @SuppressWarnings("unchecked") - Map argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class); - - String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), - argumentsMap.get("unit").toString()); - - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, - ModelOptionsUtils.toJsonString(functionResponse))); - } - - ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - - messages = this.functionCallingHelper.buildToolCallConversation(messages, assistantMessage, - toolMessageResponse); - - assertThat(messages).isNotEmpty(); - - // prompt = new Prompt(toolCallConversation, prompt.getOptions()); - } - } - while (isToolCall); - - logger.info("Response: {}", chatResponse); - - assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15"); - } - -} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java index 975a172d260..ff136aa2422 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java @@ -43,12 +43,12 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -256,8 +256,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -280,8 +279,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DockerModelRunnerWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DockerModelRunnerWithOpenAiChatModelIT.java index d0c5c9df5e6..d819b61c4de 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DockerModelRunnerWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DockerModelRunnerWithOpenAiChatModelIT.java @@ -46,12 +46,12 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -268,8 +268,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -292,8 +291,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index d670cc9c446..1bb39e97a0e 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java @@ -46,12 +46,12 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -249,8 +249,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -272,8 +271,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java index af6e5f232e0..31898265ee8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java @@ -47,7 +47,7 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.SimpleApiKey; -import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; @@ -398,7 +398,7 @@ public OpenAiApi chatCompletionApi() { public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { return OpenAiChatModel.builder() .openAiApi(openAiApi) - .toolCallingManager(LegacyToolCallingManager.builder().build()) + .toolCallingManager(ToolCallingManager.builder().build()) .defaultOptions(OpenAiChatOptions.builder().model(MISTRAL_DEFAULT_MODEL).build()) .build(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java index 88263c92382..499f6d9933b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java @@ -41,12 +41,12 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -246,8 +246,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -269,8 +268,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java index d2d6b35e57d..33d278bd1ed 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java @@ -50,8 +50,7 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.NoopApiKey; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; @@ -302,8 +301,7 @@ void streamFunctionCallTest(String modelName) { // Note for Ollama you must set the tool choice to explicitly. Unlike OpenAI // (which defaults to "auto") Ollama defaults to "nono" .toolChoice("auto") - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -421,7 +419,7 @@ public OpenAiApi chatCompletionApi() { public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { return OpenAiChatModel.builder() .openAiApi(openAiApi) - .toolCallingManager(LegacyToolCallingManager.builder().build()) + .toolCallingManager(ToolCallingManager.builder().build()) .defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_OLLAMA_MODEL).build()) .build(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java index e2041ca29f3..319eda56039 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java @@ -41,12 +41,12 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -255,8 +255,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -277,8 +276,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 8d9a4278b21..af524b611bb 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -75,10 +75,7 @@ import org.springframework.ai.content.Media; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; -import org.springframework.ai.model.tool.LegacyToolCallingManager; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; @@ -182,70 +179,6 @@ public class VertexAiGeminiChatModel implements ChatModel, DisposableBean { */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - /** - * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. - */ - @Deprecated - public VertexAiGeminiChatModel(VertexAI vertexAI) { - this(vertexAI, VertexAiGeminiChatOptions.builder().model(ChatModel.GEMINI_1_5_PRO).temperature(0.8).build()); - } - - /** - * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. - */ - @Deprecated - public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options) { - this(vertexAI, options, null); - } - - /** - * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. - */ - @Deprecated - public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, - FunctionCallbackResolver functionCallbackResolver) { - this(vertexAI, options, functionCallbackResolver, List.of()); - } - - /** - * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. - */ - @Deprecated - public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks) { - this(vertexAI, options, functionCallbackResolver, toolFunctionCallbacks, RetryUtils.DEFAULT_RETRY_TEMPLATE); - } - - /** - * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. - */ - @Deprecated - public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - RetryTemplate retryTemplate) { - this(vertexAI, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate, - ObservationRegistry.NOOP); - } - - /** - * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. - */ - @Deprecated - public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { - - this(vertexAI, options, - LegacyToolCallingManager.builder() - .functionCallbackResolver(functionCallbackResolver) - .functionCallbacks(toolFunctionCallbacks) - .build(), - retryTemplate, observationRegistry); - logger.warn("This constructor is deprecated and will be removed in the next milestone. " - + "Please use the new constructor accepting ToolCallingManager instead."); - - } - /** * Creates a new instance of VertexAiGeminiChatModel. * @param vertexAI the Vertex AI instance to use @@ -851,10 +784,6 @@ public static final class Builder { private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); - private FunctionCallbackResolver functionCallbackResolver; - - private List toolFunctionCallbacks; - private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -883,18 +812,6 @@ public Builder toolExecutionEligibilityPredicate( return this; } - @Deprecated - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - @Deprecated - public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { - this.toolFunctionCallbacks = toolFunctionCallbacks; - return this; - } - public Builder retryTemplate(RetryTemplate retryTemplate) { this.retryTemplate = retryTemplate; return this; @@ -907,25 +824,9 @@ public Builder observationRegistry(ObservationRegistry observationRegistry) { public VertexAiGeminiChatModel build() { if (this.toolCallingManager != null) { - Assert.isNull(this.functionCallbackResolver, - "functionCallbackResolver cannot be set when toolCallingManager is set"); - Assert.isNull(this.toolFunctionCallbacks, - "toolFunctionCallbacks cannot be set when toolCallingManager is set"); - return new VertexAiGeminiChatModel(this.vertexAI, this.defaultOptions, this.toolCallingManager, this.retryTemplate, this.observationRegistry, this.toolExecutionEligibilityPredicate); } - - if (this.functionCallbackResolver != null) { - Assert.isNull(this.toolCallingManager, - "toolCallingManager cannot be set when functionCallbackResolver is set"); - List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks - : List.of(); - - return new VertexAiGeminiChatModel(this.vertexAI, this.defaultOptions, this.functionCallbackResolver, - toolCallbacks, this.retryTemplate, this.observationRegistry); - } - return new VertexAiGeminiChatModel(this.vertexAI, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, this.retryTemplate, this.observationRegistry, this.toolExecutionEligibilityPredicate); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 8b5f8619ccf..2f71b0a4099 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -30,7 +30,6 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel; @@ -101,7 +100,7 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { * completion requests. */ @JsonIgnore - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** * Collection of tool names to be resolved at runtime and used for tool calling in the @@ -234,44 +233,17 @@ public void setResponseMimeType(String mimeType) { } @Override - @JsonIgnore - @Deprecated - public List getFunctionCallbacks() { - return this.getToolCallbacks(); - } - - @Override - @JsonIgnore - @Deprecated - public void setFunctionCallbacks(List functionCallbacks) { - this.setToolCallbacks(functionCallbacks); - } - - @Override - public List getToolCallbacks() { + public List getToolCallbacks() { return this.toolCallbacks; } @Override - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = toolCallbacks; } - @Override - @JsonIgnore - @Deprecated - public Set getFunctions() { - return this.getToolNames(); - } - - @JsonIgnore - @Deprecated - public void setFunctions(Set functions) { - this.setToolNames(functions); - } - @Override public Set getToolNames() { return this.toolNames; @@ -325,19 +297,6 @@ public void setSafetySettings(List safetySettings) this.safetySettings = safetySettings; } - @Deprecated - @Override - @JsonIgnore - public Boolean getProxyToolCalls() { - return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; - } - - @Deprecated - @JsonIgnore - public void setProxyToolCalls(Boolean proxyToolCalls) { - this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; - } - @Override public Map getToolContext() { return this.toolContext; @@ -447,17 +406,12 @@ public Builder responseMimeType(String mimeType) { return this; } - @Deprecated - public Builder functionCallbacks(List functionCallbacks) { - return toolCallbacks(functionCallbacks); - } - - public Builder toolCallbacks(List toolCallbacks) { + public Builder toolCallbacks(List toolCallbacks) { this.options.toolCallbacks = toolCallbacks; return this; } - public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + public Builder toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index 444db43f20e..802fab3b8a9 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -32,7 +32,6 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.tool.definition.ToolDefinition; @@ -165,8 +164,7 @@ public void defaultOptionsTools() { .toolCallingManager(toolCallingManager) .defaultOptions(VertexAiGeminiChatOptions.builder() .model("DEFAULT_MODEL") - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -205,8 +203,7 @@ public void defaultOptionsTools() { // Override the default options function with one from the prompt requestPrompt = client.buildRequestPrompt(new Prompt("Test message content", VertexAiGeminiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService()) .description("Overridden function description") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java index 73c498831fa..33af68c57d2 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/TestVertexAiGeminiChatModel.java @@ -17,14 +17,12 @@ package org.springframework.ai.vertexai.gemini; import java.io.IOException; -import java.util.List; import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.generativeai.GenerativeModel; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.retry.support.RetryTemplate; /** @@ -35,9 +33,8 @@ public class TestVertexAiGeminiChatModel extends VertexAiGeminiChatModel { private GenerativeModel mockGenerativeModel; public TestVertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, RetryTemplate retryTemplate) { - super(vertexAI, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate); + super(vertexAI, options, ToolCallingManager.builder().build(), retryTemplate, null); } @Override diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java index 577e4b2086a..39ac598a181 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java @@ -17,7 +17,6 @@ package org.springframework.ai.vertexai.gemini; import java.io.IOException; -import java.util.Collections; import java.util.List; import com.google.cloud.vertexai.VertexAI; @@ -77,7 +76,7 @@ public void setUp() { .topP(1.0) .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .build(), - null, Collections.emptyList(), this.retryTemplate); + this.retryTemplate); this.chatModel.setMockGenerativeModel(this.mockGenerativeModel); } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelFunctionCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelFunctionCallingIT.java index f0604e48eac..863c7055d4e 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelFunctionCallingIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelFunctionCallingIT.java @@ -35,6 +35,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.ai.util.json.schema.JsonSchemaGenerator; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; @@ -84,7 +85,7 @@ public void functionCallExplicitOpenApiSchema() { """; var promptOptions = VertexAiGeminiChatOptions.builder() - .functionCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) .description("Get the current weather in a given location") .inputSchema(openApiSchema) .inputType(MockWeatherService.Request.class) @@ -108,7 +109,7 @@ public void functionCallTestInferredOpenApiSchema() { var promptOptions = VertexAiGeminiChatOptions.builder() .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) - .functionCallbacks(List.of( + .toolCallbacks(List.of( FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) .inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class, JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) @@ -149,7 +150,7 @@ public void functionCallTestInferredOpenApiSchema2() { var promptOptions = VertexAiGeminiChatOptions.builder() .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) - .functionCallbacks(List.of( + .toolCallbacks(List.of( FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) .description("Get the current weather in a given location.") .inputType(MockWeatherService.Request.class) @@ -186,7 +187,7 @@ public void functionCallTestInferredOpenApiSchemaStream() { var promptOptions = VertexAiGeminiChatOptions.builder() .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) - .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class, JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) .description("Get the current weather in a given location") @@ -247,7 +248,8 @@ public VertexAiGeminiChatModel vertexAiEmbedding(VertexAI vertexAi) { VertexAiGeminiChatOptions.builder() .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) .temperature(0.9) - .build()); + .build(), + ToolCallingManager.builder().build(), null, null); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java index ee0574d7645..9ef439a8477 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java @@ -36,8 +36,8 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; @@ -146,7 +146,7 @@ private AdvisedRequest before(AdvisedRequest request) { logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); + logger.info("Function names: " + request.toolNames()); logger.info("Options: " + request.chatOptions().toString()); @@ -225,7 +225,7 @@ public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ToolCallingM @Bean ToolCallingManager toolCallingManager(GenericApplicationContext applicationContext, - List toolCallbacks, ObjectProvider observationRegistry) { + List toolCallbacks, ObjectProvider observationRegistry) { var staticToolCallbackResolver = new StaticToolCallbackResolver(toolCallbacks); var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java index 5d5fe556a4d..2cb65c80f0f 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java @@ -36,8 +36,8 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.ai.tool.ToolCallbacks; import org.springframework.ai.tool.annotation.Tool; @@ -144,7 +144,7 @@ private AdvisedRequest before(AdvisedRequest request) { logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); + logger.info("Function names: " + request.toolNames()); logger.info("Options: " + request.chatOptions().toString()); @@ -220,10 +220,10 @@ public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ToolCallingM @Bean ToolCallingManager toolCallingManager(GenericApplicationContext applicationContext, - List tcps, List functionCallbacks, + List tcps, List toolCallbacks, ObjectProvider observationRegistry) { - List allFunctionCallbacks = new ArrayList(functionCallbacks); + List allFunctionCallbacks = new ArrayList(toolCallbacks); tcps.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionCallbacks::addAll); var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionCallbacks); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java index 47a245e446a..9fc92285318 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java @@ -35,8 +35,8 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; @@ -145,7 +145,7 @@ private AdvisedRequest before(AdvisedRequest request) { logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); + logger.info("Function names: " + request.toolNames()); logger.info("Options: " + request.chatOptions().toString()); @@ -216,7 +216,7 @@ public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ToolCallingM @Bean ToolCallingManager toolCallingManager(GenericApplicationContext applicationContext, - List toolCallbacks, ObjectProvider observationRegistry) { + List toolCallbacks, ObjectProvider observationRegistry) { var staticToolCallbackResolver = new StaticToolCallbackResolver(toolCallbacks); var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 6a8d5525faf..14ca4890d9f 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -34,7 +34,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; import org.springframework.ai.converter.StructuredOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; import org.springframework.core.ParameterizedTypeReference; @@ -222,7 +221,7 @@ interface ChatClientRequestSpec { ChatClientRequestSpec tools(String... toolNames); - ChatClientRequestSpec tools(FunctionCallback... toolCallbacks); + ChatClientRequestSpec tools(ToolCallback... toolCallbacks); ChatClientRequestSpec tools(List toolCallbacks); @@ -230,12 +229,6 @@ interface ChatClientRequestSpec { ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProviders); - @Deprecated - ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); - - @Deprecated - ChatClientRequestSpec functions(String... functionBeanNames); - ChatClientRequestSpec toolContext(Map toolContext); ChatClientRequestSpec system(String text); @@ -291,7 +284,7 @@ interface Builder { Builder defaultTools(String... toolNames); - Builder defaultTools(FunctionCallback... toolCallbacks); + Builder defaultTools(ToolCallback... toolCallbacks); Builder defaultTools(List toolCallbacks); @@ -299,18 +292,6 @@ interface Builder { Builder defaultTools(ToolCallbackProvider... toolCallbackProviders); - /** - * @deprecated in favor of {@link #defaultTools(String...)} - */ - @Deprecated - Builder defaultFunctions(String... functionNames); - - /** - * @deprecated in favor of {@link #defaultTools(Object...)} - */ - @Deprecated - Builder defaultFunctions(FunctionCallback... functionCallbacks); - Builder defaultToolContext(Map toolContext); Builder clone(); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 5e4abd074c4..3a4655ed95e 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -32,19 +32,14 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.ChatModelCallAdvisor; import org.springframework.ai.chat.client.advisor.ChatModelStreamAdvisor; +import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.ToolCallbackProvider; -import org.springframework.ai.tool.ToolCallbacks; -import org.springframework.lang.NonNull; -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation; @@ -62,9 +57,12 @@ import org.springframework.ai.content.Media; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.ToolCallbacks; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; +import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -135,7 +133,7 @@ public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(Advise return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), - advisedRequest.functionCallbacks(), advisedRequest.messages(), advisedRequest.functionNames(), + advisedRequest.toolCallbacks(), advisedRequest.messages(), advisedRequest.toolNames(), advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), advisedRequest.advisorParams(), observationRegistry, customObservationConvention, advisedRequest.toolContext()); @@ -650,7 +648,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final List toolNames = new ArrayList<>(); - private final List toolCallbacks = new ArrayList<>(); + private final List toolCallbacks = new ArrayList<>(); private final List messages = new ArrayList<>(); @@ -684,7 +682,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, Map userParams, @Nullable String systemText, Map systemParams, - List toolCallbacks, List messages, List toolNames, List media, + List toolCallbacks, List messages, List toolNames, List media, @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext) { @@ -777,11 +775,11 @@ public List getMedia() { return this.media; } - public List getFunctionNames() { + public List getToolNames() { return this.toolNames; } - public List getFunctionCallbacks() { + public List getToolCallbacks() { return this.toolCallbacks; } @@ -873,7 +871,7 @@ public ChatClientRequestSpec tools(String... toolNames) { } @Override - public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) { + public ChatClientRequestSpec tools(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks.addAll(List.of(toolCallbacks)); @@ -906,19 +904,6 @@ public ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProviders return this; } - @Deprecated // Use tools() - public ChatClientRequestSpec functions(String... functionBeanNames) { - return tools(functionBeanNames); - } - - @Deprecated // Use tools() - public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) { - Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); - Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements"); - this.toolCallbacks.addAll(Arrays.asList(functionCallbacks)); - return this; - } - public ChatClientRequestSpec toolContext(Map toolContext) { Assert.notNull(toolContext, "toolContext cannot be null"); Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements"); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 594fd355222..02b3e29f681 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -32,11 +32,10 @@ import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -157,7 +156,7 @@ public Builder defaultTools(String... toolNames) { } @Override - public Builder defaultTools(FunctionCallback... toolCallbacks) { + public Builder defaultTools(ToolCallback... toolCallbacks) { this.defaultRequest.tools(toolCallbacks); return this; } @@ -182,28 +181,7 @@ public Builder defaultTools(ToolCallbackProvider... toolCallbackProviders) { @Deprecated // Use defaultTools() public Builder defaultFunction(String name, String description, java.util.function.Function function) { - this.defaultRequest - .functions(FunctionCallback.builder().function(name, function).description(description).build()); - return this; - } - - @Deprecated // Use defaultTools() - public Builder defaultFunction(String name, String description, - java.util.function.BiFunction biFunction) { - this.defaultRequest - .functions(FunctionCallback.builder().function(name, biFunction).description(description).build()); - return this; - } - - @Deprecated // Use defaultTools() - public Builder defaultFunctions(String... functionNames) { - this.defaultRequest.functions(functionNames); - return this; - } - - @Deprecated // Use defaultTools() - public Builder defaultFunctions(FunctionCallback... functionCallbacks) { - this.defaultRequest.functions(functionCallbacks); + this.defaultRequest.tools(FunctionToolCallback.builder(name, function).description(description).build()); return this; } @@ -216,9 +194,9 @@ void addMessages(List messages) { this.defaultRequest.messages(messages); } - void addToolCallbacks(List toolCallbacks) { + void addToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - this.defaultRequest.tools(toolCallbacks.toArray(FunctionCallback[]::new)); + this.defaultRequest.tools(toolCallbacks.toArray(ToolCallback[]::new)); } void addToolContext(Map toolContext) { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index ff0fb8c14ec..9a518c95573 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -23,8 +23,8 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.function.Function; import java.util.Objects; +import java.util.function.Function; import org.springframework.ai.chat.client.ChatClientAttributes; import org.springframework.ai.chat.client.ChatClientRequest; @@ -36,9 +36,8 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.content.Media; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -53,8 +52,8 @@ * @param systemText the text provided by the system * @param chatOptions the options for the chat * @param media the list of media items - * @param functionNames the list of function names - * @param functionCallbacks the list of function callbacks + * @param toolNames the list of function names + * @param toolCallbacks the list of function callbacks * @param messages the list of messages * @param userParams the map of user parameters * @param systemParams the map of system parameters @@ -68,7 +67,6 @@ * @deprecated Use {@link ChatClientRequest} instead. * @since 1.0.0 */ -@Deprecated public record AdvisedRequest( // @formatter:off ChatModel chatModel, @@ -78,8 +76,8 @@ public record AdvisedRequest( @Nullable ChatOptions chatOptions, List media, - List functionNames, - List functionCallbacks, + List toolNames, + List toolCallbacks, List messages, Map userParams, Map systemParams, @@ -97,10 +95,10 @@ public record AdvisedRequest( "userText cannot be null or empty unless messages are provided and contain Tool Response message."); Assert.notNull(media, "media cannot be null"); Assert.noNullElements(media, "media cannot contain null elements"); - Assert.notNull(functionNames, "functionNames cannot be null"); - Assert.noNullElements(functionNames, "functionNames cannot contain null elements"); - Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); - Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements"); + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); Assert.notNull(userParams, "userParams cannot be null"); @@ -135,8 +133,8 @@ public static Builder from(AdvisedRequest from) { builder.systemText = from.systemText; builder.chatOptions = from.chatOptions; builder.media = from.media; - builder.functionNames = from.functionNames; - builder.functionCallbacks = from.functionCallbacks; + builder.toolNames = from.toolNames; + builder.toolCallbacks = from.toolCallbacks; builder.messages = from.messages; builder.userParams = from.userParams; builder.systemParams = from.systemParams; @@ -179,8 +177,8 @@ public static AdvisedRequest from(ChatClientRequest from) { builder.chatOptions = Objects.requireNonNullElse(from.prompt().getOptions(), ChatOptions.builder().build()); if (from.prompt().getOptions() instanceof ToolCallingChatOptions options) { - builder.functionNames = options.getToolNames().stream().toList(); - builder.functionCallbacks = options.getToolCallbacks(); + builder.toolNames = options.getToolNames().stream().toList(); + builder.toolCallbacks = options.getToolCallbacks(); builder.toolContext = options.getToolContext(); } @@ -231,15 +229,15 @@ public Prompt toPrompt() { messages.add(new UserMessage(processedUserText, this.media())); } - if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!this.functionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); + if (this.chatOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + if (!this.toolNames().isEmpty()) { + toolCallingChatOptions.setToolNames(new HashSet<>(this.toolNames())); } - if (!this.functionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); + if (!this.toolCallbacks().isEmpty()) { + toolCallingChatOptions.setToolCallbacks(this.toolCallbacks()); } if (!CollectionUtils.isEmpty(this.toolContext())) { - functionCallingOptions.setToolContext(this.toolContext()); + toolCallingChatOptions.setToolContext(this.toolContext()); } } @@ -261,9 +259,9 @@ public static final class Builder { private List media = List.of(); - private List functionNames = List.of(); + private List toolNames = List.of(); - private List functionCallbacks = List.of(); + private List toolCallbacks = List.of(); private List messages = List.of(); @@ -333,22 +331,22 @@ public Builder media(List media) { } /** - * Set the function names. - * @param functionNames the function names + * Set the tool names. + * @param toolNames the function names * @return this {@link Builder} instance */ - public Builder functionNames(List functionNames) { - this.functionNames = functionNames; + public Builder toolNames(List toolNames) { + this.toolNames = toolNames; return this; } /** - * Set the function callbacks. - * @param functionCallbacks the function callbacks + * Set the tool callbacks. + * @param toolCallbacks the tool callbacks * @return this {@link Builder} instance */ - public Builder functionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; + public Builder functionCallbacks(List toolCallbacks) { + this.toolCallbacks = toolCallbacks; return this; } @@ -430,7 +428,7 @@ public Builder toolContext(Map toolContext) { */ public AdvisedRequest build() { return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media, - this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, + this.toolNames, this.toolCallbacks, this.messages, this.userParams, this.systemParams, this.advisors, this.advisorParams, this.adviseContext, this.toolContext); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java index d52cb3be1a6..425861aa566 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java @@ -16,6 +16,9 @@ package org.springframework.ai.chat.client.observation; +import java.util.Arrays; +import java.util.List; + import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; @@ -23,16 +26,12 @@ import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.observation.conventions.SpringAiKind; import org.springframework.ai.observation.tracing.TracingHelper; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; -import java.util.Arrays; -import java.util.List; - /** * Default conventions to populate observations for chat client workflows. * @@ -154,7 +153,10 @@ protected KeyValues toolCallbacks(KeyValues keyValues, ChatClientObservationCont return keyValues; } - var toolCallbackNames = toolCallbacks.stream().map(FunctionCallback::getName).sorted().toList(); + var toolCallbackNames = toolCallbacks.stream() + .map(toolCallback -> toolCallback.getToolDefinition().name()) + .sorted() + .toList(); return keyValues .and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS .asString(), TracingHelper.concatenateStrings(toolCallbackNames)); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 43bb6c81a53..352a7bdece3 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -40,9 +40,9 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; -import org.springframework.ai.model.function.DefaultFunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.util.MimeTypeUtils; @@ -198,7 +198,7 @@ void defaultSystemTextLambda() { @Test void mutateDefaults() { - FunctionCallingOptions options = new DefaultFunctionCallingOptions(); + ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); given(this.chatModel.getDefaultOptions()).willReturn(options); given(this.chatModel.call(this.promptCaptor.capture())) @@ -216,9 +216,8 @@ void mutateDefaults() { .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) - .defaultFunctions("fun1", "fun2") - .defaultFunctions(FunctionCallback.builder() - .function("fun3", mockFunction) + .defaultTools("fun1", "fun2") + .defaultTools(FunctionToolCallback.builder("fun3", mockFunction) .description("fun3description") .inputType(String.class) .build()) @@ -246,10 +245,10 @@ void mutateDefaults() { assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); - var fco = (FunctionCallingOptions) prompt.getOptions(); + var fco = (ToolCallingChatOptions) prompt.getOptions(); - assertThat(fco.getFunctions()).containsExactly("fun1", "fun2"); - assertThat(fco.getFunctionCallbacks().iterator().next().getName()).isEqualTo("fun3"); + assertThat(fco.getToolNames()).containsExactlyInAnyOrder("fun1", "fun2"); + assertThat(fco.getToolCallbacks().iterator().next().getName()).isEqualTo("fun3"); // Streaming content = join(chatClient.prompt().stream().content()); @@ -268,16 +267,16 @@ void mutateDefaults() { assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); - fco = (FunctionCallingOptions) prompt.getOptions(); + fco = (ToolCallingChatOptions) prompt.getOptions(); - assertThat(fco.getFunctions()).containsExactly("fun1", "fun2"); - assertThat(fco.getFunctionCallbacks().iterator().next().getName()).isEqualTo("fun3"); + assertThat(fco.getToolNames()).containsExactlyInAnyOrder("fun1", "fun2"); + assertThat(fco.getToolCallbacks().iterator().next().getName()).isEqualTo("fun3"); // mutate builder // @formatter:off chatClient = chatClient.mutate() .defaultSystem("Mutated default system text {param1}, {param2}") - .defaultFunctions("fun4") + .defaultTools("fun4") .defaultUser("Mutated default user text {uparam1}, {uparam2}") .build(); // @formatter:on @@ -298,10 +297,10 @@ void mutateDefaults() { assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); - fco = (FunctionCallingOptions) prompt.getOptions(); + fco = (ToolCallingChatOptions) prompt.getOptions(); - assertThat(fco.getFunctions()).containsExactly("fun1", "fun2", "fun4"); - assertThat(fco.getFunctionCallbacks().iterator().next().getName()).isEqualTo("fun3"); + assertThat(fco.getToolNames()).containsExactlyInAnyOrder("fun1", "fun2", "fun4"); + assertThat(fco.getToolCallbacks().iterator().next().getName()).isEqualTo("fun3"); // Streaming content = join(chatClient.prompt().stream().content()); @@ -320,17 +319,17 @@ void mutateDefaults() { assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); - fco = (FunctionCallingOptions) prompt.getOptions(); + fco = (ToolCallingChatOptions) prompt.getOptions(); - assertThat(fco.getFunctions()).containsExactly("fun1", "fun2", "fun4"); - assertThat(fco.getFunctionCallbacks().iterator().next().getName()).isEqualTo("fun3"); + assertThat(fco.getToolNames()).containsExactlyInAnyOrder("fun1", "fun2", "fun4"); + assertThat(fco.getToolCallbacks().iterator().next().getName()).isEqualTo("fun3"); } @Test void mutatePrompt() { - FunctionCallingOptions options = new DefaultFunctionCallingOptions(); + ToolCallingChatOptions options = new DefaultToolCallingChatOptions(); given(this.chatModel.getDefaultOptions()).willReturn(options); given(this.chatModel.call(this.promptCaptor.capture())) @@ -347,9 +346,8 @@ void mutatePrompt() { .defaultSystem(s -> s.text("Default system text {param1}, {param2}") .param("param1", "value1") .param("param2", "value2")) - .defaultFunctions("fun1", "fun2") - .defaultFunctions(FunctionCallback.builder() - .function("fun3", mockFunction) + .defaultTools("fun1", "fun2") + .defaultTools(FunctionToolCallback.builder("fun3", mockFunction) .description("fun3description") .inputType(String.class) .build()) @@ -365,7 +363,7 @@ void mutatePrompt() { .system("New default system text {param1}, {param2}") .user(u -> u.param("uparam1", "userValue1") .param("uparam2", "userValue2")) - .functions("fun5") + .tools("fun5") .mutate().build() // mutate and build new prompt .prompt().call().content(); // @formatter:on @@ -384,10 +382,10 @@ void mutatePrompt() { assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); - var fco = (FunctionCallingOptions) prompt.getOptions(); + var tco = (ToolCallingChatOptions) prompt.getOptions(); - assertThat(fco.getFunctions()).containsExactly("fun1", "fun2", "fun5"); - assertThat(fco.getFunctionCallbacks().iterator().next().getName()).isEqualTo("fun3"); + assertThat(tco.getToolNames()).containsExactlyInAnyOrder("fun1", "fun2", "fun5"); + assertThat(tco.getToolCallbacks().iterator().next().getName()).isEqualTo("fun3"); // Streaming // @formatter:off @@ -396,7 +394,7 @@ void mutatePrompt() { .system("New default system text {param1}, {param2}") .user(u -> u.param("uparam1", "userValue1") .param("uparam2", "userValue2")) - .functions("fun5") + .tools("fun5") .mutate().build() // mutate and build new prompt .prompt().stream().content()); // @formatter:on @@ -415,10 +413,10 @@ void mutatePrompt() { assertThat(userMessage.getMedia()).hasSize(1); assertThat(userMessage.getMedia().iterator().next().getMimeType()).isEqualTo(MimeTypeUtils.IMAGE_JPEG); - fco = (FunctionCallingOptions) prompt.getOptions(); + var tcoptions = (ToolCallingChatOptions) prompt.getOptions(); - assertThat(fco.getFunctions()).containsExactly("fun1", "fun2", "fun5"); - assertThat(fco.getFunctionCallbacks().iterator().next().getName()).isEqualTo("fun3"); + assertThat(tcoptions.getToolNames()).containsExactlyInAnyOrder("fun1", "fun2", "fun5"); + assertThat(tcoptions.getToolCallbacks().iterator().next().getName()).isEqualTo("fun3"); } @Test @@ -517,7 +515,7 @@ void complexCall() throws MalformedURLException { given(this.chatModel.call(this.promptCaptor.capture())) .willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))); - var options = FunctionCallingOptions.builder().build(); + var options = ToolCallingChatOptions.builder().build(); given(this.chatModel.getDefaultOptions()).willReturn(options); var url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); @@ -525,7 +523,7 @@ void complexCall() throws MalformedURLException { // @formatter:off ChatClient client = ChatClient.builder(this.chatModel) .defaultSystem("System text") - .defaultFunctions("function1") + .defaultTools("function1") .build(); String response = client.prompt() @@ -549,10 +547,10 @@ void complexCall() throws MalformedURLException { assertThat(userMessage.getMedia().iterator().next().getData()) .isEqualTo("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); - FunctionCallingOptions runtieOptions = (FunctionCallingOptions) this.promptCaptor.getValue().getOptions(); + ToolCallingChatOptions runtieOptions = (ToolCallingChatOptions) this.promptCaptor.getValue().getOptions(); - assertThat(runtieOptions.getFunctions()).containsExactly("function1"); - assertThat(options.getFunctions()).isEmpty(); + assertThat(runtieOptions.getToolNames()).containsExactly("function1"); + assertThat(options.getToolNames()).isEmpty(); } // Constructors diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 78415dd7752..3a8d269fd77 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -27,15 +27,15 @@ import java.util.function.Consumer; import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; -import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; -import org.springframework.ai.tool.ToolCallback; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; +import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; @@ -48,7 +48,8 @@ import org.springframework.ai.content.Media; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.io.ClassPathResource; @@ -1467,7 +1468,7 @@ void whenToolNamesThenReturn() { String toolName = "myTool"; spec = spec.tools(toolName); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; - assertThat(defaultSpec.getFunctionNames()).contains(toolName); + assertThat(defaultSpec.getToolNames()).contains(toolName); } @Test @@ -1486,7 +1487,7 @@ void whenToolCallbacksThenReturn() { ToolCallback toolCallback = mock(ToolCallback.class); spec = spec.tools(toolCallback); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; - assertThat(defaultSpec.getFunctionCallbacks()).contains(toolCallback); + assertThat(defaultSpec.getToolCallbacks()).contains(toolCallback); } // FunctionCallback.builder().description("description").function(null,input->"hello").inputType(String.class).build() @@ -1495,108 +1496,106 @@ void whenToolCallbacksThenReturn() { void whenFunctionNameIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions(FunctionCallback.builder() - .function(null, input -> "hello") + assertThatThrownBy(() -> spec.tools(FunctionToolCallback.builder(null, input -> "hello") .description("description") .inputType(String.class) - .build())).isInstanceOf(IllegalArgumentException.class).hasMessage("Name must not be empty"); + .build())).isInstanceOf(IllegalArgumentException.class).hasMessage("name cannot be null or empty"); } @Test void whenFunctionNameIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions(FunctionCallback.builder() - .function("", input -> "hello") + assertThatThrownBy(() -> spec.tools(FunctionToolCallback.builder("", input -> "hello") .description("description") .inputType(String.class) - .build())).isInstanceOf(IllegalArgumentException.class).hasMessage("Name must not be empty"); + .build())).isInstanceOf(IllegalArgumentException.class).hasMessage("name cannot be null or empty"); } @Test + @Disabled("This fails now as the FunctionToolCallback description is allowed to be empty") void whenFunctionDescriptionIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions(FunctionCallback.builder() - .function("name", input -> "hello") + assertThatThrownBy(() -> spec.tools(FunctionToolCallback.builder("name", input -> "hello") .description(null) .inputType(String.class) .build())).isInstanceOf(IllegalArgumentException.class).hasMessage("Description must not be empty"); } @Test + @Disabled("This fails now as the FunctionToolCallback description is allowed to be empty") void whenFunctionDescriptionIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions(FunctionCallback.builder() - .function("name", input -> "hello") - .description("") - .inputType(String.class) - .build())).isInstanceOf(IllegalArgumentException.class).hasMessage("Description must not be empty"); + assertThatThrownBy(() -> spec.tools( + FunctionToolCallback.builder("name", input -> "hello").description("").inputType(String.class).build())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Description must not be empty"); } @Test void whenFunctionThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.functions(FunctionCallback.builder() - .function("name", input -> "hello") + spec = spec.tools(FunctionToolCallback.builder("name", input -> "hello") .inputType(String.class) .description("description") .build()); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; - assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + assertThat(defaultSpec.getToolCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } @Test void whenFunctionAndInputTypeThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.functions(FunctionCallback.builder() - .function("name", input -> "hello") + spec = spec.tools(FunctionToolCallback.builder("name", input -> "hello") .inputType(String.class) .description("description") .build()); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; - assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + assertThat(defaultSpec.getToolCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } @Test void whenBiFunctionNameIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions( - FunctionCallback.builder().function(null, (input, ctx) -> "hello").description("description").build())) + assertThatThrownBy(() -> spec + .tools(FunctionToolCallback.builder(null, (input, ctx) -> "hello").description("description").build())) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Name must not be empty"); + .hasMessage("name cannot be null or empty"); } @Test void whenBiFunctionNameIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions( - FunctionCallback.builder().function("", (input, ctx) -> "hello").description("description").build())) + assertThatThrownBy(() -> spec + .tools(FunctionToolCallback.builder("", (input, ctx) -> "hello").description("description").build())) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Name must not be empty"); + .hasMessage("name cannot be null or empty"); } @Test + @Disabled("This fails now as the FunctionToolCallback description is allowed to be empty") void whenBiFunctionDescriptionIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec - .functions(FunctionCallback.builder().function("name", (input, ctx) -> "hello").description(null).build())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Description must not be empty"); + assertThatThrownBy(() -> spec.tools(FunctionToolCallback.builder("name", (input, ctx) -> "hello") + .inputType(String.class) + .description(null) + .build())).isInstanceOf(IllegalArgumentException.class).hasMessage("Description must not be empty"); } @Test + @Disabled("This fails now as the FunctionToolCallback description is allowed to be empty") void whenBiFunctionDescriptionIsEmptyThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec - .functions(FunctionCallback.builder().function("name", (input, ctx) -> "hello").description("").build())) + assertThatThrownBy( + () -> spec.tools(FunctionToolCallback.builder("name", (input, ctx) -> "hello").description("").build())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Description must not be empty"); } @@ -1605,20 +1604,19 @@ void whenBiFunctionDescriptionIsEmptyThenThrow() { void whenBiFunctionThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - spec = spec.functions(FunctionCallback.builder() - .function("name", (input, ctx) -> "hello") + spec = spec.tools(FunctionToolCallback.builder("name", (input, ctx) -> "hello") .description("description") .inputType(String.class) .build()); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; - assertThat(defaultSpec.getFunctionCallbacks()).anyMatch(callback -> callback.getName().equals("name")); + assertThat(defaultSpec.getToolCallbacks()).anyMatch(callback -> callback.getName().equals("name")); } @Test void whenFunctionBeanNamesElementIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions("myFunction", null)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> spec.tools("myFunction", null)).isInstanceOf(IllegalArgumentException.class) .hasMessage("toolNames cannot contain null elements"); } @@ -1627,28 +1625,28 @@ void whenFunctionBeanNamesThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); String functionBeanName = "myFunction"; - spec = spec.functions(functionBeanName); + spec = spec.tools(functionBeanName); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; - assertThat(defaultSpec.getFunctionNames()).contains(functionBeanName); + assertThat(defaultSpec.getToolNames()).contains(functionBeanName); } @Test void whenFunctionCallbacksElementIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.functions(mock(FunctionCallback.class), null)) + assertThatThrownBy(() -> spec.tools(mock(FunctionToolCallback.class), null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("functionCallbacks cannot contain null elements"); + .hasMessage("toolCallbacks cannot contain null elements"); } @Test void whenFunctionCallbacksThenReturn() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - FunctionCallback functionCallback = mock(FunctionCallback.class); - spec = spec.functions(functionCallback); + FunctionToolCallback functionToolCallback = mock(FunctionToolCallback.class); + spec = spec.tools(functionToolCallback); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; - assertThat(defaultSpec.getFunctionCallbacks()).contains(functionCallback); + assertThat(defaultSpec.getToolCallbacks()).contains(functionToolCallback); } @Test diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java index 3e4e8e0c270..0cb3fc93ff7 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java @@ -89,7 +89,7 @@ void whenFunctionNamesIsNullThenThrows() { assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), null, List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("functionNames cannot be null"); + .hasMessage("toolNames cannot be null"); } @Test @@ -97,7 +97,7 @@ void whenFunctionCallbacksIsNullThenThrows() { assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), null, List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("functionCallbacks cannot be null"); + .hasMessage("toolCallbacks cannot be null"); } @Test @@ -180,7 +180,7 @@ void whenConvertToAndFromChatClientRequest() { .userText(userMessage.getText()) .userParams(userParams) .media(userMessage.getMedia()) - .functionNames(toolNames) + .toolNames(toolNames) .functionCallbacks(List.of(toolCallback)) .toolContext(toolContext) .advisors(advisors) diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 36ba65f0eb6..bac155c1617 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -36,10 +36,11 @@ import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.observation.conventions.SpringAiKind; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; import static org.assertj.core.api.Assertions.assertThat; @@ -80,24 +81,12 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis }; } - static FunctionCallback dummyFunction(String name) { - return new FunctionCallback() { + static ToolCallback dummyFunction(String name) { + return new ToolCallback() { @Override - public String getName() { - return name; - } - - @Override - public String getDescription() { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'getDescription'"); - } - - @Override - public String getInputTypeSchema() { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'getInputTypeSchema'"); + public ToolDefinition getToolDefinition() { + return ToolDefinition.builder().name(name).inputSchema("{}").build(); } @Override diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 7a89e6e772f..86e963deb0d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -107,7 +107,7 @@ ** xref:api/chat/prompt-engineering-patterns.adoc[] * xref:api/testing.adoc[AI Model Evaluation] * xref:api/functions.adoc[Function Calling (Deprecated)] -** xref:api/function-callback.adoc[FunctionCallback API (Deprecated)] + * Service Connections ** xref:api/docker-compose.adoc[Docker Compose] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc index 37f1da40a0c..ef2410283e3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc @@ -127,11 +127,10 @@ Another way to register a function is to create a `FunctionCallback` instance li static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public FunctionToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() + return FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance .inputType(MockWeatherService.Request.class) // (3) function signature .build(); } @@ -144,7 +143,7 @@ It also provides a description (2) and input type (3) used to generate the JSON NOTE: By default, the response converter does a JSON serialization of the Response object. -NOTE: The `FunctionCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class. +NOTE: The `FunctionToolCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class. === Specifying functions in Chat Options @@ -157,7 +156,7 @@ AnthropicChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), - AnthropicChatOptions.builder().function("CurrentWeather").build())); // (1) Enable the function + AnthropicChatOptions.builder().toolNames("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); ---- @@ -177,8 +176,7 @@ AnthropicChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); var promptOptions = AnthropicChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function signature .build())) // function code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc index 8c92c8a30dd..4f2e985e9be 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/azure-open-ai-chat-functions.adoc @@ -123,10 +123,9 @@ Another way to register a function is to create a `FunctionCallback` instance li static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public FunctionToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name + return FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name .description("Get the current weather in a given location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function input type .build(); @@ -139,7 +138,7 @@ It wraps the 3rd party `MockWeatherService` function and registers it as a `Curr NOTE: The default response converter does a JSON serialization of the Response object. -NOTE: The `FunctionCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class and internally generates an JSON schema for the function call. +NOTE: The `FunctionToolCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class and internally generates an JSON schema for the function call. === Specifying functions in Chat Options @@ -152,7 +151,7 @@ AzureOpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), - AzureOpenAiChatOptions.builder().function("CurrentWeather").build())); // (1) Enable the function + AzureOpenAiChatOptions.builder().tools("CurrentWeather").build())); // (1) Enable the function logger.info("Response: {}", response); ---- @@ -182,8 +181,7 @@ AzureOpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris? Use Multi-turn function calling."); var promptOptions = AzureOpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the current weather in a given location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function input type .build())) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc index 438adbfc205..f8d1356f579 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/mistralai-chat-functions.adoc @@ -117,7 +117,7 @@ Mistral AI is almost identical to OpenAI in this regard. ==== FunctionCallback Wrapper -Another way to register a function is to create a `FunctionCallback` like this: +Another way to register a function is to create a `FunctionToolCallback` like this: [source,java] ---- @@ -125,10 +125,9 @@ Another way to register a function is to create a `FunctionCallback` like this: static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public FunctionToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + return FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function signature .build(); @@ -155,7 +154,7 @@ MistralAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, - MistralAiChatOptions.builder().function("CurrentWeather").build())); // Enable the function + MistralAiChatOptions.builder().tools("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); ---- @@ -175,8 +174,7 @@ MistralAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); var promptOptions = MistralAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function signature .build())) // function code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc index 94fc17155a8..131cadf5e49 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/ollama-chat-functions.adoc @@ -127,10 +127,9 @@ Another way to register a function is to create a `FunctionCallback` like this: static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public FunctionToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name + return FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function signature .build(); @@ -144,7 +143,7 @@ It also provides a description (2) and the function signature (3) to let the mod NOTE: By default, the response converter performs a JSON serialization of the Response object. -NOTE: The `FunctionCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class. +NOTE: The `FunctionToolCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class. === Specifying functions in Chat Options @@ -157,7 +156,7 @@ OllamaChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, - OllamaOptions.builder().function("CurrentWeather").build())); // Enable the function + OllamaOptions.builder().tools("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); ---- @@ -186,8 +185,7 @@ OllamaChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); var promptOptions = OllamaOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function signature .build())) // function code diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc index 133cea23076..efc8f256bea 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/openai-chat-functions.adoc @@ -122,10 +122,9 @@ Another way to register a function is to create a `FunctionCallback` like this: static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public FunctionToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + return FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function input type .build(); @@ -139,7 +138,7 @@ It also provides a description (2) and an input type (3) used to generate the JS NOTE: By default, the response converter performs a JSON serialization of the Response object. -NOTE: The `FunctionCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class. +NOTE: The `FunctionToolCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class. === Specifying functions in Chat Options @@ -152,7 +151,7 @@ OpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); ChatResponse response = this.chatModel.call(new Prompt(this.userMessage, - OpenAiChatOptions.builder().function("CurrentWeather").build())); // Enable the function + OpenAiChatOptions.builder().tools("CurrentWeather").build())); // Enable the function logger.info("Response: {}", response); ---- @@ -181,8 +180,7 @@ OpenAiChatModel chatModel = ... UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) function input type .build())) // function code @@ -232,8 +230,7 @@ BiFunction OpenAiChatOptions options = OpenAiChatOptions.builder() .model(OpenAiApi.ChatModel.GPT_4_O.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", this.weatherFunction) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", this.weatherFunction) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index aa796c5ae13..5de11ccda3b 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -34,7 +34,7 @@ As a developer, you need to implement a function that takes the function call ar Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatClient` or registering the function dynamically in your prompt request. Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. -The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion Builder utility class to simplify the implementation and registration of Java callback functions. +The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java[FunctionToolCallback.java] interface and the companion Builder utility class to simplify the implementation and registration of Java callback functions. == How it works @@ -110,7 +110,7 @@ We start by describing the most POJO-friendly options. In this approach, you define a `@Bean` in your application context as you would any other Spring managed object. -Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallback` that adds the logic for it being invoked via the AI model. +Internally, Spring AI `ChatModel` will create an instance of a `FunctionToolCallback` that adds the logic for it being invoked via the AI model. The name of the `@Bean` is used function name. -- @@ -191,9 +191,9 @@ data class Request(val location: String, val unit: Unit) It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. -==== FunctionCallback +==== FunctionToolCallback -Another way to register a function is to create a `FunctionCallback` like this: +Another way to register a function is to create a `FunctionToolCallback` like this: -- [tabs] @@ -206,10 +206,9 @@ Java:: static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public FunctionToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance + return FunctionToolCallback.builder("CurrentWeather", new MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) input type to build the JSON schema .build(); @@ -226,10 +225,9 @@ import org.springframework.ai.model.function.withInputType class Config { @Bean - fun weatherFunctionInfo(): FunctionCallback { + fun weatherFunctionInfo(): FunctionToolCallback { - return FunctionCallback.builder() - .function("CurrentWeather", MockWeatherService()) // (1) function name and instance + return FunctionToolCallback.builder("CurrentWeather", MockWeatherService()) // (1) function name and instance .description("Get the weather in location") // (2) function description // (3) Required due to Kotlin SAM conversion being an opaque lambda .inputType() @@ -246,7 +244,7 @@ It also provides a description (2) and an optional response converter to convert NOTE: By default, the response converter performs a JSON serialization of the Response object. -NOTE: The `FunctionCallback.Builder` internally resolves the function call signature based on the `MockWeatherService.Request` class. +NOTE: The `FunctionToolCallback.Builder` internally resolves the function call signature based on the `MockWeatherService.Request` class. === Enable functions by bean name @@ -257,7 +255,7 @@ To let the model know and call your `CurrentWeather` function you need to enable ChatClient chatClient = ... ChatResponse response = this.chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions("CurrentWeather") // Enable the function + .tools("CurrentWeather") // Enable the function .call(). chatResponse(); @@ -289,8 +287,7 @@ The client-side registration enables you to register functions by default. ChatClient chatClient = ... ChatResponse response = this.chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallback.builder() - .function("currentWeather", (Request request) -> new Response(30.0, Unit.C)) // (1) function name and instance + .tools(FunctionToolCallback.builder("currentWeather", (Request request) -> new Response(30.0, Unit.C)) // (1) function name and instance .description("Get the weather in location") // (2) function description .inputType(MockWeatherService.Request.class) // (3) input type to build the JSON schema .build()) @@ -317,7 +314,7 @@ The `MethodInvokingFunctionCallback` implements the `FunctionCallback` interface - Any parameter/return types (primitives, objects, collections) - Special handling for `ToolContext` parameters -You need the `FunctionCallback.Builder` to create `MethodInvokingFunctionCallback` like this: +You need the `MethodToolCallback.Builder` to create `MethodInvokingFunctionCallback` like this: [source,java] ---- @@ -344,11 +341,14 @@ public class WeatherService { } // Usage -FunctionCallback callback = FunctionCallback.builder() - .method("getWeather", String.class, TemperatureUnit.class) - .description("Get weather information for a city") - .targetClass(WeatherService.class) - .build(); +var toolMethod = ReflectionUtils.findMethod(WeatherService.class, "getWeather", String.class, TemperatureUnit.class); +MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) + .build(); ---- Instance Method with ToolContext:: + @@ -363,19 +363,21 @@ public class DeviceController { // Usage DeviceController controller = new DeviceController(); - +var toolMethod = ReflectionUtils.findMethod( + DeviceController.class, "setDeviceState", String.class, Boolean.class, ToolContext.class); String response = ChatClient.create(chatModel).prompt() - .user("Turn on the living room lights") - .functions(FunctionCallback.builder() - .method("setDeviceState", String.class,boolean.class,ToolContext.class) - .description("Control device state") - .targetObject(controller) - .build()) - .toolContext(Map.of("location", "home")) - .call() - .content(); + .user("Turn on the living room lights") + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Control device state") + .build()) + .toolMethod(toolMethod) + .toolObject(controller) + .build()) + .toolContext(Map.of("location", "home")) + .call() + .content(); ---- - ====== The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java[OpenAiChatClientMethodInvokingFunctionCallbackIT] @@ -425,8 +427,7 @@ BiFunction ChatResponse response = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", this.weatherFunction) + .tools(FunctionToolCallback.builder("getCurrentWeather", this.weatherFunction) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -452,15 +453,18 @@ public class DeviceController { // Usage DeviceController controller = new DeviceController(); - +var toolMethod = ReflectionUtils.findMethod( + DeviceController.class, "setDeviceState", String.class, Boolean.class, ToolContext.class); String response = ChatClient.create(chatModel).prompt() - .user("Turn on the living room lights") - .functions(FunctionCallback.builder() - .method("setDeviceState", String.class,boolean.class,ToolContext.class) - .description("Control device state") - .targetObject(controller) - .build()) - .toolContext(Map.of("location", "home")) - .call() - .content(); + .user("Turn on the living room lights") + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Control device state") + .build()) + .toolMethod(toolMethod) + .toolObject(controller) + .build()) + .toolContext(Map.of("location", "home")) + .call() + .content(); ---- diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java b/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java deleted file mode 100644 index cdbe33ba15b..00000000000 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingHelper.java +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.model.function; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; - -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.model.AbstractToolCallSupport; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.util.CollectionUtils; - -/** - * Helper class that reuses the {@link AbstractToolCallSupport} to implement the function - * call handling logic on the client side. Used when the withProxyToolCalls(true) option - * is enabled. - * - * @deprecated Use {@link ToolCallingManager} instead. - */ -@Deprecated -public class FunctionCallingHelper extends AbstractToolCallSupport { - - public FunctionCallingHelper() { - this(null, FunctionCallingOptions.builder().build(), List.of()); - } - - public FunctionCallingHelper(FunctionCallbackResolver functionCallbackResolver, - FunctionCallingOptions functionCallingOptions, List toolFunctionCallbacks) { - super(functionCallbackResolver, functionCallingOptions, toolFunctionCallbacks); - } - - @Override - public boolean isToolCall(ChatResponse chatResponse, Set toolCallFinishReasons) { - return super.isToolCall(chatResponse, toolCallFinishReasons); - } - - @Override - public List buildToolCallConversation(List previousMessages, AssistantMessage assistantMessage, - ToolResponseMessage toolResponseMessage) { - return super.buildToolCallConversation(previousMessages, assistantMessage, toolResponseMessage); - } - - @Override - public List handleToolCalls(Prompt prompt, ChatResponse response) { - return super.handleToolCalls(prompt, response); - } - - public Flux processStream(ChatModel chatModel, Prompt prompt, Set finishReasons, - Function customFunction) { - - Flux chatResponses = chatModel.stream(prompt); - - return chatResponses.flatMap(chatResponse -> { - - boolean isToolCall = this.isToolCall(chatResponse, finishReasons); - - if (isToolCall) { - - Optional toolCallGeneration = chatResponse.getResults() - .stream() - .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) - .findFirst(); - - AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); - - List toolResponses = new ArrayList<>(); - - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - - String functionResponse = customFunction.apply(toolCall); - - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), - ModelOptionsUtils.toJsonString(functionResponse))); - } - - ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - - List toolCallConversation = this.buildToolCallConversation(prompt.getInstructions(), - assistantMessage, toolMessageResponse); - - var prompt2 = new Prompt(toolCallConversation, prompt.getOptions()); - - return processStream(chatModel, prompt2, finishReasons, customFunction); - } - - return Flux.just(chatResponse); - }); - } - - public ChatResponse processCall(ChatModel chatModel, Prompt prompt, Set finishReasons, - Function customFunction) { - - ChatResponse chatResponse = chatModel.call(prompt); - - boolean isToolCall = this.isToolCall(chatResponse, finishReasons); - - if (!isToolCall) { - return chatResponse; - } - - Optional toolCallGeneration = chatResponse.getResults() - .stream() - .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) - .findFirst(); - - AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); - - List toolResponses = new ArrayList<>(); - - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - - String functionResponse = customFunction.apply(toolCall); - - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), - ModelOptionsUtils.toJsonString(functionResponse))); - } - - ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - - List toolCallConversation = this.buildToolCallConversation(prompt.getInstructions(), assistantMessage, - toolMessageResponse); - - var prompt2 = new Prompt(toolCallConversation, prompt.getOptions()); - - return processCall(chatModel, prompt2, finishReasons, customFunction); - } - - /** - * Helper used to provide only the function definition, without the actual function - * call implementation. - * - * @param name the function name - * @param description the function description - * @param inputTypeSchema the input type schema - */ - public static record FunctionDefinition(String name, String description, - String inputTypeSchema) implements FunctionCallback { - - @Override - public String getName() { - return this.name(); - } - - @Override - public String getDescription() { - return this.description(); - } - - @Override - public String getInputTypeSchema() { - return this.inputTypeSchema(); - } - - @Override - public String call(String functionInput) { - throw new UnsupportedOperationException( - "FunctionDefinition provides only metadata. It doesn't implement the call method."); - } - - } - -} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index 2363bc12390..e45e8db5f92 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -44,10 +44,10 @@ static Builder builder() { /** * Function Callbacks to be registered with the ChatModel. For Prompt Options the - * functionCallbacks are automatically enabled for the duration of the prompt - * execution. For Default Options the FunctionCallbacks are registered but disabled by - * default. You have to use "functions" property to list the function names from the - * ChatModel registry to be used in the chat completion requests. + * toolCallbacks are automatically enabled for the duration of the prompt execution. + * For Default Options the FunctionCallbacks are registered but disabled by default. + * You have to use "functions" property to list the function names from the ChatModel + * registry to be used in the chat completion requests. * @return Return the Function Callbacks to be registered with the ChatModel. */ List getFunctionCallbacks(); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index f5eca5496c0..870db6931b9 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -25,7 +25,7 @@ import java.util.Set; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -37,7 +37,7 @@ */ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { - private List toolCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); private Set toolNames = new HashSet<>(); @@ -71,12 +71,12 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { private Double topP; @Override - public List getToolCallbacks() { + public List getToolCallbacks() { return List.copyOf(this.toolCallbacks); } @Override - public void setToolCallbacks(List toolCallbacks) { + public void setToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); this.toolCallbacks = new ArrayList<>(toolCallbacks); @@ -118,37 +118,6 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } - @Override - public List getFunctionCallbacks() { - return getToolCallbacks(); - } - - @Override - public void setFunctionCallbacks(List functionCallbacks) { - setToolCallbacks(functionCallbacks); - } - - @Override - public Set getFunctions() { - return getToolNames(); - } - - @Override - public void setFunctions(Set functions) { - setToolNames(functions); - } - - @Override - @Nullable - public Boolean getProxyToolCalls() { - return getInternalToolExecutionEnabled() != null ? !getInternalToolExecutionEnabled() : null; - } - - @Override - public void setProxyToolCalls(@Nullable Boolean proxyToolCalls) { - setInternalToolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls); - } - @Override @Nullable public String getModel() { @@ -260,13 +229,13 @@ public static class Builder implements ToolCallingChatOptions.Builder { private final DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); @Override - public ToolCallingChatOptions.Builder toolCallbacks(List toolCallbacks) { + public ToolCallingChatOptions.Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; } @Override - public ToolCallingChatOptions.Builder toolCallbacks(FunctionCallback... toolCallbacks) { + public ToolCallingChatOptions.Builder toolCallbacks(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); this.options.setToolCallbacks(Arrays.asList(toolCallbacks)); return this; @@ -308,37 +277,6 @@ public ToolCallingChatOptions.Builder internalToolExecutionEnabled( return this; } - @Override - @Deprecated // Use toolCallbacks() instead - public ToolCallingChatOptions.Builder functionCallbacks(List functionCallbacks) { - return toolCallbacks(functionCallbacks); - } - - @Override - @Deprecated // Use toolCallbacks() instead - public ToolCallingChatOptions.Builder functionCallbacks(FunctionCallback... functionCallbacks) { - Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); - return functionCallbacks(List.of(functionCallbacks)); - } - - @Override - @Deprecated // Use toolNames() instead - public ToolCallingChatOptions.Builder functions(Set functions) { - return toolNames(functions); - } - - @Override - @Deprecated // Use toolNames() instead - public ToolCallingChatOptions.Builder function(String function) { - return toolNames(function); - } - - @Override - @Deprecated // Use internalToolExecutionEnabled() instead - public ToolCallingChatOptions.Builder proxyToolCalls(@Nullable Boolean proxyToolCalls) { - return internalToolExecutionEnabled(proxyToolCalls == null || !proxyToolCalls); - } - @Override public ToolCallingChatOptions.Builder model(@Nullable String model) { this.options.setModel(model); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index bf885f8da29..a3afa93e517 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -34,7 +34,6 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; @@ -89,7 +88,7 @@ public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCa public List resolveToolDefinitions(ToolCallingChatOptions chatOptions) { Assert.notNull(chatOptions, "chatOptions cannot be null"); - List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); + List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); for (String toolName : chatOptions.getToolNames()) { // Skip the tool if it is already present in the request toolCallbacks. // That might happen if a tool is defined in the options @@ -97,25 +96,14 @@ public List resolveToolDefinitions(ToolCallingChatOptions chatOp if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) { continue; } - FunctionCallback toolCallback = this.toolCallbackResolver.resolve(toolName); + ToolCallback toolCallback = this.toolCallbackResolver.resolve(toolName); if (toolCallback == null) { throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); } toolCallbacks.add(toolCallback); } - return toolCallbacks.stream().map(functionCallback -> { - if (functionCallback instanceof ToolCallback toolCallback) { - return toolCallback.getToolDefinition(); - } - else { - return ToolDefinition.builder() - .name(functionCallback.getName()) - .description(functionCallback.getDescription()) - .inputSchema(functionCallback.getInputTypeSchema()) - .build(); - } - }).toList(); + return toolCallbacks.stream().map(toolCallback -> toolCallback.getToolDefinition()).toList(); } @Override @@ -151,9 +139,9 @@ public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResp private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) { Map toolContextMap = Map.of(); - if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions - && !CollectionUtils.isEmpty(functionOptions.getToolContext())) { - toolContextMap = new HashMap<>(functionOptions.getToolContext()); + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions + && !CollectionUtils.isEmpty(toolCallingChatOptions.getToolContext())) { + toolContextMap = new HashMap<>(toolCallingChatOptions.getToolContext()); List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), @@ -181,13 +169,10 @@ private static List buildConversationHistoryBeforeToolExecution(Prompt */ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage, ToolContext toolContext) { - List toolCallbacks = List.of(); + List toolCallbacks = List.of(); if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { toolCallbacks = toolCallingChatOptions.getToolCallbacks(); } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) { - toolCallbacks = functionOptions.getFunctionCallbacks(); - } List toolResponses = new ArrayList<>(); @@ -200,7 +185,7 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) String toolName = toolCall.name(); String toolInputArguments = toolCall.arguments(); - FunctionCallback toolCallback = toolCallbacks.stream() + ToolCallback toolCallback = toolCallbacks.stream() .filter(tool -> toolName.equals(tool.getName())) .findFirst() .orElseGet(() -> this.toolCallbackResolver.resolve(toolName)); @@ -209,17 +194,11 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); } - if (returnDirect == null && toolCallback instanceof ToolCallback callback) { - returnDirect = callback.getToolMetadata().returnDirect(); - } - else if (toolCallback instanceof ToolCallback callback) { - returnDirect = returnDirect && callback.getToolMetadata().returnDirect(); + if (returnDirect == null) { + returnDirect = toolCallback.getToolMetadata().returnDirect(); } - else if (returnDirect == null) { - // This is a temporary solution to ensure backward compatibility with - // FunctionCallback. - // TODO: remove this block when FunctionCallback is removed. - returnDirect = false; + else { + returnDirect = returnDirect && toolCallback.getToolMetadata().returnDirect(); } String toolResult; diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java deleted file mode 100644 index 760c8537e71..00000000000 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.model.tool; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.model.AbstractToolCallSupport; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.model.ToolContext; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; -import org.springframework.ai.tool.execution.ToolExecutionException; -import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; - -/** - * Implementation of {@link ToolCallingManager} supporting the migration from - * {@link AbstractToolCallSupport} to {@link ToolCallingManager} and ensuring AI - * compatibility for all the ChatModel implementations. - * - * @author Thomas Vitale - * @since 1.0.0 - * @deprecated Only to help moving away from {@link AbstractToolCallSupport}. It will be - * removed in the next milestone. - */ -@Deprecated -public class LegacyToolCallingManager implements ToolCallingManager { - - @Nullable - private final FunctionCallbackResolver functionCallbackResolver; - - private final Map functionCallbacks = new HashMap<>(); - - private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DefaultToolExecutionExceptionProcessor - .builder() - .build(); - - public LegacyToolCallingManager(@Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List functionCallbacks) { - this.functionCallbackResolver = functionCallbackResolver; - if (functionCallbacks != null) { - functionCallbacks.forEach(toolCallback -> this.functionCallbacks.put(toolCallback.getName(), toolCallback)); - } - } - - @Override - public List resolveToolDefinitions(ToolCallingChatOptions chatOptions) { - Assert.notNull(chatOptions, "chatOptions cannot be null"); - - List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); - for (String toolName : chatOptions.getToolNames()) { - // Skip the tool if it is already present in the request toolCallbacks. - // That might happen if a tool is defined in the options - // both as a ToolCallback and as a tool name. - if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) { - continue; - } - FunctionCallback toolCallback = resolveFunctionCallback(toolName); - if (toolCallback == null) { - throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); - } - toolCallbacks.add(toolCallback); - } - - return toolCallbacks.stream().map(functionCallback -> { - if (functionCallback instanceof ToolCallback toolCallback) { - return toolCallback.getToolDefinition(); - } - else { - return ToolDefinition.builder() - .name(functionCallback.getName()) - .description(functionCallback.getDescription()) - .inputSchema(functionCallback.getInputTypeSchema()) - .build(); - } - }).toList(); - } - - @Nullable - private FunctionCallback resolveFunctionCallback(String toolName) { - Assert.hasText(toolName, "toolName cannot be null or empty"); - if (this.functionCallbacks.get(toolName) != null) { - return this.functionCallbacks.get(toolName); - } - return this.functionCallbackResolver != null ? this.functionCallbackResolver.resolve(toolName) : null; - } - - @Override - public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) { - Assert.notNull(prompt, "prompt cannot be null"); - Assert.notNull(chatResponse, "chatResponse cannot be null"); - - Optional toolCallGeneration = chatResponse.getResults() - .stream() - .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) - .findFirst(); - - if (toolCallGeneration.isEmpty()) { - throw new IllegalStateException("No tool call requested by the chat model"); - } - - AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); - - ToolContext toolContext = buildToolContext(prompt, assistantMessage); - - ToolResponseMessage toolMessageResponse = executeToolCall(prompt, assistantMessage, toolContext); - - List conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(), - assistantMessage, toolMessageResponse); - - return ToolExecutionResult.builder().conversationHistory(conversationHistory).returnDirect(false).build(); - } - - private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) { - Map toolContextMap = Map.of(); - - if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions - && !CollectionUtils.isEmpty(functionOptions.getToolContext())) { - toolContextMap = new HashMap<>(functionOptions.getToolContext()); - - List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); - messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); - - toolContextMap.put(ToolContext.TOOL_CALL_HISTORY, - buildConversationHistoryBeforeToolExecution(prompt, assistantMessage)); - } - - return new ToolContext(toolContextMap); - } - - private static List buildConversationHistoryBeforeToolExecution(Prompt prompt, - AssistantMessage assistantMessage) { - List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); - messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), - assistantMessage.getToolCalls())); - return messageHistory; - } - - /** - * Execute the tool call and return the response message. To ensure backward - * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are - * supported. - */ - private ToolResponseMessage executeToolCall(Prompt prompt, AssistantMessage assistantMessage, - ToolContext toolContext) { - List toolCallbacks = List.of(); - if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { - toolCallbacks = toolCallingChatOptions.getToolCallbacks(); - } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) { - toolCallbacks = functionOptions.getFunctionCallbacks(); - } - - List toolResponses = new ArrayList<>(); - - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - - String toolName = toolCall.name(); - String toolInputArguments = toolCall.arguments(); - - FunctionCallback toolCallback = toolCallbacks.stream() - .filter(tool -> toolName.equals(tool.getName())) - .findFirst() - .orElseGet(() -> resolveFunctionCallback(toolName)); - - if (toolCallback == null) { - throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); - } - - String toolResult; - try { - toolResult = toolCallback.call(toolInputArguments, toolContext); - } - catch (ToolExecutionException ex) { - toolResult = this.toolExecutionExceptionProcessor.process(ex); - } - - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult)); - } - - return new ToolResponseMessage(toolResponses, Map.of()); - } - - private List buildConversationHistoryAfterToolExecution(List previousMessages, - AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) { - List messages = new ArrayList<>(previousMessages); - messages.add(assistantMessage); - messages.add(toolResponseMessage); - return messages; - } - - public static Builder builder() { - return new Builder(); - } - - public final static class Builder { - - private FunctionCallbackResolver functionCallbackResolver; - - private List functionCallbacks = new ArrayList<>(); - - private Builder() { - } - - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - public Builder functionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; - return this; - } - - public LegacyToolCallingManager build() { - return new LegacyToolCallingManager(this.functionCallbackResolver, this.functionCallbacks); - } - - } - -} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index 50bd7adf230..17280809a5a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -25,8 +25,7 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.util.ToolUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -40,19 +39,19 @@ * @author Ilayaperumal Gopinathan * @since 1.0.0 */ -public interface ToolCallingChatOptions extends FunctionCallingOptions { +public interface ToolCallingChatOptions extends ChatOptions { boolean DEFAULT_TOOL_EXECUTION_ENABLED = true; /** * ToolCallbacks to be registered with the ChatModel. */ - List getToolCallbacks(); + List getToolCallbacks(); /** * Set the ToolCallbacks to be registered with the ChatModel. */ - void setToolCallbacks(List toolCallbacks); + void setToolCallbacks(List toolCallbacks); /** * Names of the tools to register with the ChatModel. @@ -87,6 +86,18 @@ default Boolean isInternalToolExecutionEnabled() { */ void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled); + /** + * Get the configured tool context. + * @return the tool context map. + */ + Map getToolContext(); + + /** + * Set the tool context values as map. + * @param toolContext as map + */ + void setToolContext(Map toolContext); + /** * A builder to create a new {@link ToolCallingChatOptions} instance. */ @@ -102,10 +113,6 @@ static boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) { internalToolExecutionEnabled = Boolean.TRUE .equals(toolCallingChatOptions.getInternalToolExecutionEnabled()); } - else if (chatOptions instanceof FunctionCallingOptions functionCallingOptions - && functionCallingOptions.getProxyToolCalls() != null) { - internalToolExecutionEnabled = Boolean.TRUE.equals(!functionCallingOptions.getProxyToolCalls()); - } else { internalToolExecutionEnabled = DEFAULT_TOOL_EXECUTION_ENABLED; } @@ -121,8 +128,8 @@ static Set mergeToolNames(Set runtimeToolNames, Set defa return new HashSet<>(runtimeToolNames); } - static List mergeToolCallbacks(List runtimeToolCallbacks, - List defaultToolCallbacks) { + static List mergeToolCallbacks(List runtimeToolCallbacks, + List defaultToolCallbacks) { Assert.notNull(runtimeToolCallbacks, "runtimeToolCallbacks cannot be null"); Assert.notNull(defaultToolCallbacks, "defaultToolCallbacks cannot be null"); if (CollectionUtils.isEmpty(runtimeToolCallbacks)) { @@ -142,7 +149,7 @@ static Map mergeToolContext(Map runtimeToolConte return mergedToolContext; } - static void validateToolCallbacks(List toolCallbacks) { + static void validateToolCallbacks(List toolCallbacks) { List duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks); if (!duplicateToolNames.isEmpty()) { throw new IllegalStateException("Multiple tools with the same name (%s) found in ToolCallingChatOptions" @@ -153,17 +160,17 @@ static void validateToolCallbacks(List toolCallbacks) { /** * A builder to create a {@link ToolCallingChatOptions} instance. */ - interface Builder extends FunctionCallingOptions.Builder { + interface Builder extends ChatOptions.Builder { /** * ToolCallbacks to be registered with the ChatModel. */ - Builder toolCallbacks(List functionCallbacks); + Builder toolCallbacks(List toolCallbacks); /** * ToolCallbacks to be registered with the ChatModel. */ - Builder toolCallbacks(FunctionCallback... functionCallbacks); + Builder toolCallbacks(ToolCallback... toolCallbacks); /** * Names of the tools to register with the ChatModel. @@ -181,34 +188,21 @@ interface Builder extends FunctionCallingOptions.Builder { */ Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled); - // FunctionCallingOptions.Builder methods - - @Override + /** + * Add a {@link Map} of context values into tool context. + * @param context the map representing the tool context. + * @return the {@link ToolCallingChatOptions} Builder. + */ Builder toolContext(Map context); - @Override + /** + * Add a specific key/value pair to the tool context. + * @param key the key to use. + * @param value the corresponding value. + * @return the {@link ToolCallingChatOptions} Builder. + */ Builder toolContext(String key, Object value); - @Override - @Deprecated // Use toolCallbacks() instead - Builder functionCallbacks(List functionCallbacks); - - @Override - @Deprecated // Use toolCallbacks() instead - Builder functionCallbacks(FunctionCallback... functionCallbacks); - - @Override - @Deprecated // Use tools() instead - Builder functions(Set functions); - - @Override - @Deprecated // Use tools() instead - Builder function(String function); - - @Override - @Deprecated // Use internalToolExecutionEnabled() instead - Builder proxyToolCalls(@Nullable Boolean proxyToolCalls); - // ChatOptions.Builder methods @Override diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java index 71116d80034..e4aa55fcbcd 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/StaticToolCallbackProvider.java @@ -18,12 +18,11 @@ import java.util.List; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.util.Assert; /** * A simple implementation of {@link ToolCallbackProvider} that maintains a static array - * of {@link FunctionCallback} objects. This provider is immutable after construction and + * of {@link ToolCallback} objects. This provider is immutable after construction and * provides a straightforward way to supply a fixed set of tool callbacks to AI models. * *

@@ -46,11 +45,11 @@ * @author Christian Tzolov * @since 1.0.0 * @see ToolCallbackProvider - * @see FunctionCallback + * @see ToolCallback */ public class StaticToolCallbackProvider implements ToolCallbackProvider { - private final FunctionCallback[] toolCallbacks; + private final ToolCallback[] toolCallbacks; /** * Constructs a new StaticToolCallbackProvider with the specified array of function @@ -59,7 +58,7 @@ public class StaticToolCallbackProvider implements ToolCallbackProvider { * provider. Must not be null, though an empty array is permitted. * @throws IllegalArgumentException if the toolCallbacks array is null */ - public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) { + public StaticToolCallbackProvider(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "ToolCallbacks must not be null"); this.toolCallbacks = toolCallbacks; } @@ -72,9 +71,9 @@ public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) { * @throws IllegalArgumentException if the toolCallbacks list is null or contains null * elements */ - public StaticToolCallbackProvider(List toolCallbacks) { + public StaticToolCallbackProvider(List toolCallbacks) { Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); - this.toolCallbacks = toolCallbacks.toArray(new FunctionCallback[0]); + this.toolCallbacks = toolCallbacks.toArray(new ToolCallback[0]); } /** @@ -84,7 +83,7 @@ public StaticToolCallbackProvider(List toolCallbacks * are expected to be immutable. */ @Override - public FunctionCallback[] getToolCallbacks() { + public ToolCallback[] getToolCallbacks() { return this.toolCallbacks; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java index 17ff1e73e84..960e6c83b58 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/ToolCallbackProvider.java @@ -18,8 +18,6 @@ import java.util.List; -import org.springframework.ai.model.function.FunctionCallback; - /** * Provides {@link ToolCallback} instances for tools defined in different sources. * @@ -28,13 +26,13 @@ */ public interface ToolCallbackProvider { - FunctionCallback[] getToolCallbacks(); + ToolCallback[] getToolCallbacks(); - static ToolCallbackProvider from(List toolCallbacks) { + static ToolCallbackProvider from(List toolCallbacks) { return new StaticToolCallbackProvider(toolCallbacks); } - static ToolCallbackProvider from(FunctionCallback... toolCallbacks) { + static ToolCallbackProvider from(ToolCallback... toolCallbacks) { return new StaticToolCallbackProvider(toolCallbacks); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java index 28ce9da98e0..dcb1249289c 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java @@ -18,7 +18,7 @@ import java.util.List; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -41,11 +41,11 @@ public DelegatingToolCallbackResolver(List toolCallbackRes @Override @Nullable - public FunctionCallback resolve(String toolName) { + public ToolCallback resolve(String toolName) { Assert.hasText(toolName, "toolName cannot be null or empty"); for (ToolCallbackResolver toolCallbackResolver : this.toolCallbackResolvers) { - FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName); + ToolCallback toolCallback = toolCallbackResolver.resolve(toolName); if (toolCallback != null) { return toolCallback; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java index 3b2a77b64de..c3848a18c43 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java @@ -23,7 +23,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; import org.springframework.util.Assert; @@ -37,22 +36,18 @@ public class StaticToolCallbackResolver implements ToolCallbackResolver { private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class); - private final Map toolCallbacks = new HashMap<>(); + private final Map toolCallbacks = new HashMap<>(); - public StaticToolCallbackResolver(List toolCallbacks) { + public StaticToolCallbackResolver(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); - toolCallbacks.forEach(callback -> { - if (callback instanceof ToolCallback toolCallback) { - this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback); - } - this.toolCallbacks.put(callback.getName(), callback); - }); + toolCallbacks + .forEach(toolCallback -> this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback)); } @Override - public FunctionCallback resolve(String toolName) { + public ToolCallback resolve(String toolName) { Assert.hasText(toolName, "toolName cannot be null or empty"); logger.debug("ToolCallback resolution attempt from static registry"); return this.toolCallbacks.get(toolName); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java index 1155e4042e3..259be087309 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java @@ -32,6 +32,6 @@ public interface ToolCallbackResolver { * Resolve the {@link FunctionCallback} for the given tool name. */ @Nullable - FunctionCallback resolve(String toolName); + ToolCallback resolve(String toolName); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/util/ToolUtils.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/util/ToolUtils.java index 4df5383e9f6..5d07d28311e 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/util/ToolUtils.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/util/ToolUtils.java @@ -22,7 +22,7 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; import org.springframework.ai.tool.execution.ToolCallResultConverter; @@ -84,10 +84,11 @@ public static ToolCallResultConverter getToolCallResultConverter(Method method) } } - public static List getDuplicateToolNames(List toolCallbacks) { + public static List getDuplicateToolNames(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); return toolCallbacks.stream() - .collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting())) + .collect(Collectors.groupingBy(toolCallback -> toolCallback.getToolDefinition().name(), + Collectors.counting())) .entrySet() .stream() .filter(entry -> entry.getValue() > 1) @@ -95,7 +96,7 @@ public static List getDuplicateToolNames(List toolCall .collect(Collectors.toList()); } - public static List getDuplicateToolNames(FunctionCallback... toolCallbacks) { + public static List getDuplicateToolNames(ToolCallback... toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); return getDuplicateToolNames(Arrays.asList(toolCallbacks)); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java index 247e82b6f00..bf8e0e1fd01 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java @@ -23,8 +23,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -96,25 +96,24 @@ void shouldCopyOptions() { @Test void shouldUpcastToChatOptions() { // Given - FunctionCallback callback = FunctionCallback.builder() - .function("function1", x -> "result") + FunctionToolCallback callback = FunctionToolCallback.builder("function1", x -> "result") .description("Test function") .inputType(String.class) .build(); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + ToolCallingChatOptions toolCallingChatOptions = ToolCallingChatOptions.builder() .model("gpt-4") .maxTokens(100) .temperature(0.7) .topP(1.0) .topK(40) .stopSequences(List.of("stop1", "stop2")) - .functions(Set.of("function1", "function2")) - .functionCallbacks(List.of(callback)) + .toolNames(Set.of("function1", "function2")) + .toolCallbacks(List.of(callback)) .build(); // When - ChatOptions chatOptions = functionOptions; + ChatOptions chatOptions = toolCallingChatOptions; // Then assertThat(chatOptions.getModel()).isEqualTo("gpt-4"); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilderTests.java deleted file mode 100644 index 6a414770849..00000000000 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilderTests.java +++ /dev/null @@ -1,333 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.model.function; - -import java.util.function.BiFunction; -import java.util.function.Function; - -import org.junit.jupiter.api.Test; - -import org.springframework.ai.model.function.FunctionCallback.FunctionInvokingSpec; -import org.springframework.ai.model.function.FunctionCallback.MethodInvokingSpec; -import org.springframework.core.ParameterizedTypeReference; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for {@link DefaultFunctionCallbackBuilder}. - * - * @author Christian Tzolov - */ -class DefaultFunctionCallbackBuilderTests { - - // Function - @Test - void whenFunctionDescriptionIsNullThenThrow() { - assertThatThrownBy( - () -> FunctionCallback.builder().function("functionName", input -> "output").description(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Description must not be empty"); - } - - @Test - void whenFunctionDescriptionIsEmptyThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().function("functionName", input -> "output").description("")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Description must not be empty"); - } - - @Test - void whenFunctionInputTypeSchemaIsNullThenThrow() { - assertThatThrownBy( - () -> FunctionCallback.builder().function("functionName", input -> "output").inputTypeSchema(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("InputTypeSchema must not be empty"); - } - - @Test - void whenFunctionInputTypeSchemaIsEmptyThenThrow() { - assertThatThrownBy( - () -> FunctionCallback.builder().function("functionName", input -> "output").inputTypeSchema("")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("InputTypeSchema must not be empty"); - } - - @Test - void whenFunctionSchemaTypeIsNullThenThrow() { - assertThatThrownBy( - () -> FunctionCallback.builder().function("functionName", input -> "output").schemaType(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("SchemaType must not be null"); - } - - @Test - void whenFunctionResponseConverterIsNullThenThrow() { - assertThatThrownBy( - () -> FunctionCallback.builder().function("functionName", input -> "output").responseConverter(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("ResponseConverter must not be null"); - } - - @Test - void whenFunctionNameIsNullThenThrow2() { - assertThatThrownBy(() -> FunctionCallback.builder().function(null, (Function) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Name must not be empty"); - } - - @Test - void whenFunctionIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().function("functionName", (Function) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Function must not be null"); - } - - @Test - void whenFunctionThenReturn() { - FunctionInvokingSpec functionBuilder = FunctionCallback.builder() - .function("functionName", input -> "output"); - assertThat(functionBuilder).isNotNull(); - } - - @Test - void whenFunctionWithNullInputTypeThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().function("functionName", input -> "output").build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("InputType must not be null"); - } - - @Test - void whenFunctionWithInputTypeThenReturn() { - FunctionCallback functionCallback = FunctionCallback.builder() - .function("functionName", input -> "output") - .description("description") - .inputType(String.class) - .build(); - assertThat(functionCallback).isNotNull(); - assertThat(functionCallback.getDescription()).isEqualTo("description"); - assertThat(functionCallback.getName()).isEqualTo("functionName"); - assertThat(functionCallback.getInputTypeSchema()).isNotEmpty(); - } - - @Test - void whenFunctionWithGeneratedDescriptionThenReturn() { - FunctionCallback functionCallback = FunctionCallback.builder() - .function("veryLongDescriptiveFunctionName", input -> "output") - .inputType(String.class) - .build(); - assertThat(functionCallback.getDescription()).isEqualTo("very long descriptive function name"); - assertThat(functionCallback.getName()).isEqualTo("veryLongDescriptiveFunctionName"); - } - - @Test - void whenFunctionWithGenericInputTypeThenReturn() { - FunctionCallback functionCallback = FunctionCallback.builder() - .function("functionName", input -> "output") - .inputType(new ParameterizedTypeReference>() { - }) - .build(); - assertThat(functionCallback.getName()).isEqualTo("functionName"); - assertThat(functionCallback.getInputTypeSchema()).isEqualTo(""" - { - "$schema" : "https://json-schema.org/draft/2020-12/schema", - "type" : "object", - "properties" : { - "datum" : { - "type" : "object", - "properties" : { - "value" : { - "type" : "string" - } - } - } - } - }"""); - } - - // BiFunction - @Test - void whenBiFunctionNameIsNullThenThrow2() { - assertThatThrownBy(() -> FunctionCallback.builder().function(null, (BiFunction) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Name must not be empty"); - } - - @Test - void whenBiFunctionIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().function("functionName", (BiFunction) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("BiFunction must not be null"); - } - - @Test - void whenBiFunctionThenReturn() { - FunctionInvokingSpec functionBuilder = FunctionCallback.builder() - .function("functionName", (input, context) -> "output"); - assertThat(functionBuilder).isNotNull(); - } - - // Method - - @Test - void whenMethodDescriptionIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").description(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Description must not be empty"); - } - - @Test - void whenMethodDescriptionIsEmptyThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").description("")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Description must not be empty"); - } - - @Test - void whenMethodInputTypeSchemaIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").inputTypeSchema(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("InputTypeSchema must not be empty"); - } - - @Test - void whenMethodInputTypeSchemaIsEmptyThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").inputTypeSchema("")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("InputTypeSchema must not be empty"); - } - - @Test - void whenMethodSchemaTypeIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").schemaType(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("SchemaType must not be null"); - } - - @Test - void whenMethodResponseConverterIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").responseConverter(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("ResponseConverter must not be null"); - } - - @Test - void whenMethodNameIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Method name must not be null"); - } - - @Test - void whenMethodArgumentTypesIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName", (Class[]) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Argument types must not be null"); - } - - @Test - void whenMethodThenReturn() { - MethodInvokingSpec methodInvokeBuilder = FunctionCallback.builder().method("methodName"); - assertThat(methodInvokeBuilder).isNotNull(); - } - - @Test - void whenMethodWithArgumentTypesThenReturn() { - MethodInvokingSpec methodInvokeBuilder = FunctionCallback.builder() - .method("methodName", String.class, Integer.class); - assertThat(methodInvokeBuilder).isNotNull(); - } - - @Test - void whenMethodWithMissingTargetObjectOrTargetClassThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Target class or object must not be null"); - } - - @Test - void whenMethodWithMissingTargetObjectThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder() - .method("methodName", String.class, Integer.class) - .targetClass(TestClass.class) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Function object must be provided for non-static methods!"); - } - - @Test - void whenMethodNotExistingThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder().method("methodName").targetClass(TestClass.class).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Method: 'methodName' with arguments:[] not found!"); - } - - @Test - void whenMethodAndNameIsNullThenThrow() { - assertThatThrownBy(() -> FunctionCallback.builder() - .method("staticMethodName", String.class, Integer.class) - .targetClass(TestClass.class) - .name(null) - .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Name must not be empty"); - } - - @Test - void whenMethodAndTargetClassThenReturn() { - var functionCallback = FunctionCallback.builder() - .method("staticMethodName", String.class, Integer.class) - .targetClass(TestClass.class) - .build(); - assertThat(functionCallback).isNotNull(); - } - - @Test - void whenMethodAndTargetObjectThenReturn() { - var functionCallback = FunctionCallback.builder() - .method("methodName", String.class, Integer.class) - .targetObject(new TestClass()) - .build(); - assertThat(functionCallback).isNotNull(); - } - - public static class TestClass { - - public static String staticMethodName(String arg1, Integer arg2) { - return arg1 + arg2; - } - - public String methodName(String arg1, Integer arg2) { - return arg1 + arg2; - } - - } - - public record Request(String value) { - } - - public static class GenericsRequest { - - private T datum; - - public T getDatum() { - return this.datum; - } - - public void setDatum(T value) { - this.datum = value; - } - - } - -} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilderTests.java deleted file mode 100644 index fe9973eae55..00000000000 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilderTests.java +++ /dev/null @@ -1,580 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.model.function; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import org.springframework.ai.chat.prompt.ChatOptions; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Unit tests for {@link DefaultFunctionCallingOptionsBuilder}. - * - */ -class DefaultFunctionCallingOptionsBuilderTests { - - private DefaultFunctionCallingOptionsBuilder builder; - - @BeforeEach - void setUp() { - this.builder = new DefaultFunctionCallingOptionsBuilder(); - } - - // Tests for inherited ChatOptions properties - - @Test - void shouldBuildWithModel() { - // When - ChatOptions options = this.builder.model("gpt-4").build(); - - // Then - assertThat(options.getModel()).isEqualTo("gpt-4"); - } - - @Test - void shouldBuildWithFrequencyPenalty() { - // When - ChatOptions options = this.builder.frequencyPenalty(0.5).build(); - - // Then - assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); - } - - @Test - void shouldBuildWithMaxTokens() { - // When - ChatOptions options = this.builder.maxTokens(100).build(); - - // Then - assertThat(options.getMaxTokens()).isEqualTo(100); - } - - @Test - void shouldBuildWithPresencePenalty() { - // When - ChatOptions options = this.builder.presencePenalty(0.7).build(); - - // Then - assertThat(options.getPresencePenalty()).isEqualTo(0.7); - } - - @Test - void shouldBuildWithStopSequences() { - // Given - List stopSequences = List.of("stop1", "stop2"); - - // When - ChatOptions options = this.builder.stopSequences(stopSequences).build(); - - // Then - assertThat(options.getStopSequences()).hasSize(2).containsExactlyElementsOf(stopSequences); - } - - @Test - void shouldBuildWithTemperature() { - // When - ChatOptions options = this.builder.temperature(0.8).build(); - - // Then - assertThat(options.getTemperature()).isEqualTo(0.8); - } - - @Test - void shouldBuildWithTopK() { - // When - ChatOptions options = this.builder.topK(5).build(); - - // Then - assertThat(options.getTopK()).isEqualTo(5); - } - - @Test - void shouldBuildWithTopP() { - // When - ChatOptions options = this.builder.topP(0.9).build(); - - // Then - assertThat(options.getTopP()).isEqualTo(0.9); - } - - @Test - void shouldBuildWithAllInheritedOptions() { - // When - ChatOptions options = this.builder.model("gpt-4") - .frequencyPenalty(0.5) - .maxTokens(100) - .presencePenalty(0.7) - .stopSequences(List.of("stop1", "stop2")) - .temperature(0.8) - .topK(5) - .topP(0.9) - .build(); - - // Then - assertThat(options.getModel()).isEqualTo("gpt-4"); - assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); - assertThat(options.getMaxTokens()).isEqualTo(100); - assertThat(options.getPresencePenalty()).isEqualTo(0.7); - assertThat(options.getStopSequences()).containsExactly("stop1", "stop2"); - assertThat(options.getTemperature()).isEqualTo(0.8); - assertThat(options.getTopK()).isEqualTo(5); - assertThat(options.getTopP()).isEqualTo(0.9); - } - - // Original FunctionCallingOptions tests - - @Test - void shouldBuildWithFunctionCallbacksList() { - // Given - FunctionCallback callback1 = FunctionCallback.builder() - .function("test1", (String input) -> "result1") - .description("Test function 1") - .inputType(String.class) - .build(); - FunctionCallback callback2 = FunctionCallback.builder() - .function("test2", (String input) -> "result2") - .description("Test function 2") - .inputType(String.class) - .build(); - List callbacks = List.of(callback1, callback2); - - // When - FunctionCallingOptions options = this.builder.functionCallbacks(callbacks).build(); - - // Then - assertThat(options.getFunctionCallbacks()).hasSize(2).containsExactlyElementsOf(callbacks); - } - - @Test - void shouldBuildWithFunctionCallbacksVarargs() { - // Given - FunctionCallback callback1 = FunctionCallback.builder() - .function("test1", (String input) -> "result1") - .description("Test function 1") - .inputType(String.class) - .build(); - FunctionCallback callback2 = FunctionCallback.builder() - .function("test2", (String input) -> "result2") - .description("Test function 2") - .inputType(String.class) - .build(); - - // When - FunctionCallingOptions options = this.builder.functionCallbacks(callback1, callback2).build(); - - // Then - assertThat(options.getFunctionCallbacks()).hasSize(2).containsExactly(callback1, callback2); - } - - @Test - void shouldThrowExceptionWhenFunctionCallbacksVarargsIsNull() { - assertThatThrownBy(() -> this.builder.functionCallbacks((FunctionCallback[]) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("FunctionCallbacks must not be null"); - } - - @Test - void shouldBuildWithFunctionsSet() { - // Given - Set functions = Set.of("function1", "function2"); - - // When - FunctionCallingOptions options = this.builder.functions(functions).build(); - - // Then - assertThat(options.getFunctions()).hasSize(2).containsExactlyInAnyOrderElementsOf(functions); - } - - @Test - void shouldBuildWithSingleFunction() { - // When - FunctionCallingOptions options = this.builder.function("function1").function("function2").build(); - - // Then - assertThat(options.getFunctions()).hasSize(2).containsExactlyInAnyOrder("function1", "function2"); - } - - @Test - void shouldThrowExceptionWhenFunctionIsNull() { - assertThatThrownBy(() -> this.builder.function(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Function must not be null"); - } - - @Test - void shouldBuildWithProxyToolCalls() { - // When - FunctionCallingOptions options = this.builder.proxyToolCalls(true).build(); - - // Then - assertThat(options.getProxyToolCalls()).isTrue(); - } - - @Test - void shouldBuildWithToolContextMap() { - // Given - Map context = Map.of("key1", "value1", "key2", 42); - - // When - FunctionCallingOptions options = this.builder.toolContext(context).build(); - - // Then - assertThat(options.getToolContext()).hasSize(2).containsAllEntriesOf(context); - } - - @Test - void shouldThrowExceptionWhenToolContextMapIsNull() { - assertThatThrownBy(() -> this.builder.toolContext((Map) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Tool context must not be null"); - } - - @Test - void shouldBuildWithToolContextKeyValue() { - // When - FunctionCallingOptions options = this.builder.toolContext("key1", "value1").toolContext("key2", 42).build(); - - // Then - assertThat(options.getToolContext()).hasSize(2).containsEntry("key1", "value1").containsEntry("key2", 42); - } - - @Test - void shouldThrowExceptionWhenToolContextKeyIsNull() { - assertThatThrownBy(() -> this.builder.toolContext(null, "value")).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Key must not be null"); - } - - @Test - void shouldThrowExceptionWhenToolContextValueIsNull() { - assertThatThrownBy(() -> this.builder.toolContext("key", null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("Value must not be null"); - } - - @Test - void shouldMergeToolContextMaps() { - // Given - Map context1 = Map.of("key1", "value1", "key2", 42); - Map context2 = Map.of("key2", "updated", "key3", true); - - // When - FunctionCallingOptions options = this.builder.toolContext(context1).toolContext(context2).build(); - - // Then - assertThat(options.getToolContext()).hasSize(3) - .containsEntry("key1", "value1") - .containsEntry("key2", "updated") - .containsEntry("key3", true); - } - - @Test - void shouldBuildWithAllOptions() { - // Given - FunctionCallback callback = FunctionCallback.builder() - .function("test", (String input) -> "result") - .description("Test function") - .inputType(String.class) - .build(); - Set functions = Set.of("function1"); - Map context = Map.of("key1", "value1"); - - // When - FunctionCallingOptions options = this.builder.model("gpt-4") - .frequencyPenalty(0.5) - .maxTokens(100) - .presencePenalty(0.7) - .stopSequences(List.of("stop1", "stop2")) - .temperature(0.8) - .topK(5) - .topP(0.9) - .functionCallbacks(callback) - .functions(functions) - .proxyToolCalls(true) - .toolContext(context) - .build(); - - // Then - assertThat(options.getFunctionCallbacks()).hasSize(1).containsExactly(callback); - assertThat(options.getFunctions()).hasSize(1).containsExactlyElementsOf(functions); - assertThat(options.getProxyToolCalls()).isTrue(); - assertThat(options.getToolContext()).hasSize(1).containsAllEntriesOf(context); - - ChatOptions chatOptions = options; - assertThat(chatOptions.getModel()).isEqualTo("gpt-4"); - assertThat(chatOptions.getFrequencyPenalty()).isEqualTo(0.5); - assertThat(chatOptions.getMaxTokens()).isEqualTo(100); - assertThat(chatOptions.getPresencePenalty()).isEqualTo(0.7); - assertThat(chatOptions.getStopSequences()).containsExactly("stop1", "stop2"); - assertThat(chatOptions.getTemperature()).isEqualTo(0.8); - assertThat(chatOptions.getTopK()).isEqualTo(5); - assertThat(chatOptions.getTopP()).isEqualTo(0.9); - } - - @Test - void shouldBuildWithEmptyFunctionCallbacks() { - // When - FunctionCallingOptions options = this.builder.functionCallbacks(List.of()).build(); - - // Then - assertThat(options.getFunctionCallbacks()).isEmpty(); - } - - @Test - void shouldBuildWithEmptyFunctions() { - // When - FunctionCallingOptions options = this.builder.functions(Set.of()).build(); - - // Then - assertThat(options.getFunctions()).isEmpty(); - } - - @Test - void shouldBuildWithEmptyToolContext() { - // When - FunctionCallingOptions options = this.builder.toolContext(Map.of()).build(); - - // Then - assertThat(options.getToolContext()).isEmpty(); - } - - @Test - void shouldDeduplicateFunctions() { - // When - FunctionCallingOptions options = this.builder.function("function1") - .function("function1") // Duplicate - .function("function2") - .build(); - - // Then - assertThat(options.getFunctions()).hasSize(2).containsExactlyInAnyOrder("function1", "function2"); - } - - @Test - void shouldCopyAllOptions() { - // Given - FunctionCallback callback = FunctionCallback.builder() - .function("test", (String input) -> "result") - .description("Test function") - .inputType(String.class) - .build(); - FunctionCallingOptions original = this.builder.model("gpt-4") - .frequencyPenalty(0.5) - .maxTokens(100) - .presencePenalty(0.7) - .stopSequences(List.of("stop1", "stop2")) - .temperature(0.8) - .topK(5) - .topP(0.9) - .functionCallbacks(callback) - .function("function1") - .proxyToolCalls(true) - .toolContext("key1", "value1") - .build(); - - // When - FunctionCallingOptions copy = original.copy(); - - // Then - assertThat(copy).usingRecursiveComparison().isEqualTo(original); - // Verify collections are actually copied - assertThat(copy.getFunctionCallbacks()).isNotSameAs(original.getFunctionCallbacks()); - assertThat(copy.getFunctions()).isNotSameAs(original.getFunctions()); - assertThat(copy.getToolContext()).isNotSameAs(original.getToolContext()); - } - - @Test - void shouldMergeWithFunctionCallingOptions() { - // Given - FunctionCallback callback1 = FunctionCallback.builder() - .function("test1", (String input) -> "result1") - .description("Test function 1") - .inputType(String.class) - .build(); - FunctionCallback callback2 = FunctionCallback.builder() - .function("test2", (String input) -> "result2") - .description("Test function 2") - .inputType(String.class) - .build(); - - DefaultFunctionCallingOptions options1 = (DefaultFunctionCallingOptions) this.builder.model("gpt-4") - .temperature(0.8) - .functionCallbacks(callback1) - .function("function1") - .proxyToolCalls(true) - .toolContext("key1", "value1") - .build(); - - DefaultFunctionCallingOptions options2 = (DefaultFunctionCallingOptions) FunctionCallingOptions.builder() - .model("gpt-3.5") - .maxTokens(100) - .functionCallbacks(callback2) - .function("function2") - .proxyToolCalls(false) - .toolContext("key2", "value2") - .build(); - - // When - FunctionCallingOptions merged = options1.merge(options2); - - // Then - assertThat(merged.getModel()).isEqualTo("gpt-3.5"); // Overridden - assertThat(merged.getTemperature()).isEqualTo(0.8); // Kept - assertThat(merged.getMaxTokens()).isEqualTo(100); // Added - assertThat(merged.getFunctionCallbacks()).containsExactly(callback1, callback2); // Combined - assertThat(merged.getFunctions()).containsExactlyInAnyOrder("function1", "function2"); // Combined - assertThat(merged.getProxyToolCalls()).isFalse(); // Overridden - assertThat(merged.getToolContext()).containsEntry("key1", "value1").containsEntry("key2", "value2"); // Combined - } - - @Test - void shouldMergeWithChatOptions() { - // Given - FunctionCallback callback = FunctionCallback.builder() - .function("test", (String input) -> "result") - .description("Test function") - .inputType(String.class) - .build(); - - DefaultFunctionCallingOptions options1 = (DefaultFunctionCallingOptions) this.builder.model("gpt-4") - .temperature(0.8) - .functionCallbacks(callback) - .function("function1") - .proxyToolCalls(true) - .toolContext("key1", "value1") - .build(); - - ChatOptions options2 = ChatOptions.builder().model("gpt-3.5").maxTokens(100).build(); - - // When - FunctionCallingOptions merged = options1.merge(options2); - - // Then - assertThat(merged.getModel()).isEqualTo("gpt-3.5"); // Overridden - assertThat(merged.getTemperature()).isEqualTo(0.8); // Kept - assertThat(merged.getMaxTokens()).isEqualTo(100); // Added - // Function-specific options should be preserved - assertThat(merged.getFunctionCallbacks()).containsExactly(callback); - assertThat(merged.getFunctions()).containsExactly("function1"); - assertThat(merged.getProxyToolCalls()).isTrue(); - assertThat(merged.getToolContext()).containsEntry("key1", "value1"); - } - - @Test - void shouldAllowBuilderReuse() { - // Given - FunctionCallback callback1 = FunctionCallback.builder() - .function("test1", (String input) -> "result1") - .description("Test function 1") - .inputType(String.class) - .build(); - FunctionCallback callback2 = FunctionCallback.builder() - .function("test2", (String input) -> "result2") - .description("Test function 2") - .inputType(String.class) - .build(); - - // When - FunctionCallingOptions options1 = this.builder.model("model1") - .temperature(0.7) - .functionCallbacks(callback1) - .build(); - - FunctionCallingOptions options2 = this.builder.model("model2").functionCallbacks(callback2).build(); - - // Then - assertThat(options1.getModel()).isEqualTo("model1"); - assertThat(options1.getTemperature()).isEqualTo(0.7); - assertThat(options1.getFunctionCallbacks()).containsExactly(callback1); - - assertThat(options2.getModel()).isEqualTo("model2"); - assertThat(options2.getTemperature()).isEqualTo(0.7); // Retains previous value - assertThat(options2.getFunctionCallbacks()).containsExactly(callback2); // Replaces - // previous - // callbacks - } - - @Test - void shouldReturnSameBuilderInstanceOnEachMethod() { - // When - FunctionCallingOptions.Builder returnedBuilder = this.builder.model("test"); - - // Then - assertThat(returnedBuilder).isSameAs(this.builder); - } - - @Test - void shouldHaveExpectedDefaultValues() { - // When - FunctionCallingOptions options = this.builder.build(); - - // Then - // ChatOptions defaults - assertThat(options.getModel()).isNull(); - assertThat(options.getTemperature()).isNull(); - assertThat(options.getMaxTokens()).isNull(); - assertThat(options.getTopP()).isNull(); - assertThat(options.getTopK()).isNull(); - assertThat(options.getFrequencyPenalty()).isNull(); - assertThat(options.getPresencePenalty()).isNull(); - assertThat(options.getStopSequences()).isNull(); - - // FunctionCallingOptions specific defaults - assertThat(options.getFunctionCallbacks()).isEmpty(); - assertThat(options.getFunctions()).isEmpty(); - assertThat(options.getToolContext()).isEmpty(); - assertThat(options.getProxyToolCalls()).isFalse(); - } - - @Test - void shouldBeImmutableAfterBuild() { - // Given - FunctionCallback callback = FunctionCallback.builder() - .function("test", (String input) -> "result") - .description("Test function") - .inputType(String.class) - .build(); - - List stopSequences = new ArrayList<>(List.of("stop1", "stop2")); - Set functions = new HashSet<>(Set.of("function1", "function2")); - Map context = new HashMap<>(Map.of("key1", "value1")); - - FunctionCallingOptions options = this.builder.stopSequences(stopSequences) - .functionCallbacks(callback) - .functions(functions) - .toolContext(context) - .build(); - - // Then - assertThatThrownBy(() -> options.getStopSequences().add("stop3")) - .isInstanceOf(UnsupportedOperationException.class); - assertThatThrownBy(() -> options.getFunctionCallbacks().add(callback)) - .isInstanceOf(UnsupportedOperationException.class); - assertThatThrownBy(() -> options.getFunctions().add("function3")) - .isInstanceOf(UnsupportedOperationException.class); - assertThatThrownBy(() -> options.getToolContext().put("key2", "value2")) - .isInstanceOf(UnsupportedOperationException.class); - } - -} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/function/MethodInvokingFunctionCallbackTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/function/MethodInvokingFunctionCallbackTests.java deleted file mode 100644 index ec3e0c300c2..00000000000 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/function/MethodInvokingFunctionCallbackTests.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.model.function; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public class MethodInvokingFunctionCallbackTests { - - private static final Map arguments = new ConcurrentHashMap<>(); - - String value = """ - { - "unit": "CELSIUS", - "city": "Barcelona", - "intNumber": 123, - "record": { - "foo": "foo", - "bar": "bar" - }, - "intList": [1, 2, 3] - } - """; - - @BeforeEach - public void beforeEach() { - arguments.clear(); - } - - @Test - public void staticMethod() throws NoSuchMethodException, SecurityException { - - var functionCallback = FunctionCallback.builder() - .method("myStaticMethod", String.class, Unit.class, int.class, MyRecord.class, List.class) - .description("weather at location") - .objectMapper(new ObjectMapper()) - .targetClass(TestClassWithFunctionMethods.class) - .build(); - - String response = functionCallback.call(this.value); - - assertThat(response).isEqualTo("23"); - - assertThat(arguments).hasSize(5); - assertThat(arguments.get("city")).isEqualTo("Barcelona"); - assertThat(arguments.get("unit")).isEqualTo(Unit.CELSIUS); - assertThat(arguments.get("intNumber")).isEqualTo(123); - assertThat(arguments.get("record")).isEqualTo(new MyRecord("foo", "bar")); - assertThat(arguments.get("intList")).isEqualTo(List.of(1, 2, 3)); - } - - @Test - public void nonStaticMethod() throws NoSuchMethodException, SecurityException { - - var object = new TestClassWithFunctionMethods(); - - var functionCallback = FunctionCallback.builder() - .method("myNonStaticMethod", String.class, Unit.class, int.class, MyRecord.class, List.class) - .description("weather at location") - .targetObject(object) - .build(); - - String response = functionCallback.call(this.value); - - assertThat(response).isEqualTo("23"); - - assertThat(arguments).hasSize(5); - assertThat(arguments.get("city")).isEqualTo("Barcelona"); - assertThat(arguments.get("unit")).isEqualTo(Unit.CELSIUS); - assertThat(arguments.get("intNumber")).isEqualTo(123); - assertThat(arguments.get("record")).isEqualTo(new MyRecord("foo", "bar")); - assertThat(arguments.get("intList")).isEqualTo(List.of(1, 2, 3)); - } - - @Test - public void noArgsNoReturnMethod() throws NoSuchMethodException, SecurityException { - - var functionCallback = FunctionCallback.builder() - .method("argumentLessReturnVoid") - .description("weather at location") - .objectMapper(new ObjectMapper()) - .targetClass(TestClassWithFunctionMethods.class) - .build(); - - String response = functionCallback.call(this.value); - - assertThat(response).isEqualTo("Done"); - - assertThat(arguments.get("method called")).isEqualTo("argumentLessReturnVoid"); - } - - record MyRecord(String foo, String bar) { - } - - public enum Unit { - - CELSIUS, FAHRENHEIT - - } - - public static class TestClassWithFunctionMethods { - - public static void argumentLessReturnVoid() { - arguments.put("method called", "argumentLessReturnVoid"); - } - - public static String myStaticMethod(String city, Unit unit, int intNumber, MyRecord record, - List intList) { - System.out.println("City: " + city + " Unit: " + unit + " intNumber: " + intNumber + " Record: " + record - + " List: " + intList); - - arguments.put("city", city); - arguments.put("unit", unit); - arguments.put("intNumber", intNumber); - arguments.put("record", record); - arguments.put("intList", intList); - - return "23"; - } - - public String myNonStaticMethod(String city, Unit unit, int intNumber, MyRecord record, List intList) { - System.out.println("City: " + city + " Unit: " + unit + " intNumber: " + intNumber + " Record: " + record - + " List: " + intList); - - arguments.put("city", city); - arguments.put("unit", unit); - arguments.put("intNumber", intNumber); - arguments.put("record", record); - arguments.put("intList", intList); - - return "23"; - } - - } - -} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java index 0f42925b553..45557f23a6d 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java @@ -23,7 +23,6 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; import static org.assertj.core.api.Assertions.assertThat; @@ -42,7 +41,7 @@ void setToolCallbacksShouldStoreToolCallbacks() { DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); ToolCallback callback1 = mock(ToolCallback.class); ToolCallback callback2 = mock(ToolCallback.class); - List callbacks = List.of(callback1, callback2); + List callbacks = List.of(callback1, callback2); options.setToolCallbacks(callbacks); @@ -221,22 +220,19 @@ void builderShouldSupportToolContextAddition() { void deprecatedMethodsShouldWorkCorrectly() { DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions(); - FunctionCallback callback1 = mock(FunctionCallback.class); + ToolCallback callback1 = mock(ToolCallback.class); ToolCallback callback2 = mock(ToolCallback.class); - options.setFunctionCallbacks(List.of(callback1, callback2)); - assertThat(options.getFunctionCallbacks()).hasSize(2); + options.setToolCallbacks(List.of(callback1, callback2)); + assertThat(options.getToolCallbacks()).hasSize(2); options.setToolNames(Set.of("tool1")); - assertThat(options.getFunctions()).containsExactly("tool1"); + assertThat(options.getToolNames()).containsExactly("tool1"); - options.setFunctions(Set.of("function1")); + options.setToolNames(Set.of("function1")); assertThat(options.getToolNames()).containsExactly("function1"); options.setInternalToolExecutionEnabled(true); - assertThat(options.getProxyToolCalls()).isFalse(); - - options.setProxyToolCalls(true); - assertThat(options.getInternalToolExecutionEnabled()).isFalse(); + assertThat(options.getInternalToolExecutionEnabled()).isTrue(); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java index 021a88ca606..37ca8b7168a 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolExecutionEligibilityPredicateTests.java @@ -112,22 +112,6 @@ void whenFunctionCallingOptionsAndToolExecutionEnabled() { assertThat(result).isTrue(); } - @Test - void whenFunctionCallingOptionsAndToolExecutionDisabled() { - // Create a FunctionCallingOptions with proxy tool calls enabled (which means - // internal tool execution is disabled) - FunctionCallingOptions options = FunctionCallingOptions.builder().proxyToolCalls(true).build(); - - // Create a ChatResponse with tool calls - AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("id1", "function", "testTool", "{}"); - AssistantMessage assistantMessage = new AssistantMessage("test", Map.of(), List.of(toolCall)); - ChatResponse chatResponse = new ChatResponse(List.of(new Generation(assistantMessage))); - - // Test the predicate - boolean result = this.predicate.test(options, chatResponse); - assertThat(result).isFalse(); - } - @Test void whenRegularChatOptionsAndHasToolCalls() { // Create regular ChatOptions (not ToolCallingChatOptions or diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/LegacyToolCallingManagerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/LegacyToolCallingManagerTests.java deleted file mode 100644 index 89a57cf02cb..00000000000 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/LegacyToolCallingManagerTests.java +++ /dev/null @@ -1,214 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.model.tool; - -import java.util.List; -import java.util.Map; - -import org.junit.jupiter.api.Test; - -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.ai.tool.execution.ToolExecutionException; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; - -/** - * Unit tests for {@link LegacyToolCallingManager}. - * - * @author Thomas Vitale - */ -class LegacyToolCallingManagerTests { - - // RESOLVE TOOL DEFINITIONS - - @Test - void whenChatOptionsIsNullThenThrow() { - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); - assertThatThrownBy(() -> toolCallingManager.resolveToolDefinitions(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("chatOptions cannot be null"); - } - - @Test - void whenToolCallbackExistsThenResolve() { - ToolCallback toolCallback = new TestToolCallback("toolA"); - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() - .functionCallbacks(List.of(toolCallback)) - .build(); - - List toolDefinitions = toolCallingManager - .resolveToolDefinitions(ToolCallingChatOptions.builder().toolNames("toolA").build()); - - assertThat(toolDefinitions).containsExactly(toolCallback.getToolDefinition()); - } - - @Test - void whenToolCallbackDoesNotExistThenThrow() { - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().functionCallbacks(List.of()).build(); - - assertThatThrownBy(() -> toolCallingManager - .resolveToolDefinitions(ToolCallingChatOptions.builder().toolNames("toolB").build())) - .isInstanceOf(IllegalStateException.class) - .hasMessage("No ToolCallback found for tool name: toolB"); - } - - // EXECUTE TOOL CALLS - - @Test - void whenPromptIsNullThenThrow() { - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); - assertThatThrownBy(() -> toolCallingManager.executeToolCalls(null, mock(ChatResponse.class))) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("prompt cannot be null"); - } - - @Test - void whenChatResponseIsNullThenThrow() { - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); - assertThatThrownBy(() -> toolCallingManager.executeToolCalls(mock(Prompt.class), null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("chatResponse cannot be null"); - } - - @Test - void whenNoToolCallInChatResponseThenThrow() { - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); - assertThatThrownBy(() -> toolCallingManager.executeToolCalls(mock(Prompt.class), - ChatResponse.builder().generations(List.of()).build())) - .isInstanceOf(IllegalStateException.class) - .hasMessage("No tool call requested by the chat model"); - } - - @Test - void whenSingleToolCallInChatResponseThenExecute() { - ToolCallback toolCallback = new TestToolCallback("toolA"); - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() - .functionCallbacks(List.of(toolCallback)) - .build(); - - Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); - ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) - .build(); - - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"))); - - ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); - - assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); - assertThat(toolExecutionResult.returnDirect()).isFalse(); - } - - @Test - void whenMultipleToolCallsInChatResponseThenExecute() { - ToolCallback toolCallbackA = new TestToolCallback("toolA"); - ToolCallback toolCallbackB = new TestToolCallback("toolB"); - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() - .functionCallbacks(List.of(toolCallbackA, toolCallbackB)) - .build(); - - Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); - ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), - new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) - .build(); - - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"), - new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!"))); - - ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); - - assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); - assertThat(toolExecutionResult.returnDirect()).isFalse(); - } - - @Test - void whenToolCallWithExceptionThenReturnError() { - ToolCallback toolCallback = new FailingToolCallback("toolC"); - ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() - .functionCallbacks(List.of(toolCallback)) - .build(); - - Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); - ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}")))))) - .build(); - - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolC", "toolC", "You failed this city!"))); - - ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); - - assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); - assertThat(toolExecutionResult.returnDirect()).isFalse(); - } - - static class TestToolCallback implements ToolCallback { - - private final ToolDefinition toolDefinition; - - TestToolCallback(String name) { - this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); - } - - @Override - public ToolDefinition getToolDefinition() { - return this.toolDefinition; - } - - @Override - public String call(String toolInput) { - return "Mission accomplished!"; - } - - } - - static class FailingToolCallback implements ToolCallback { - - private final ToolDefinition toolDefinition; - - FailingToolCallback(String name) { - this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); - } - - @Override - public ToolDefinition getToolDefinition() { - return this.toolDefinition; - } - - @Override - public String call(String toolInput) { - throw new ToolExecutionException(this.toolDefinition, new IllegalStateException("You failed this city!")); - } - - } - -} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java index b8aeff06d14..62645ba84ae 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -22,7 +22,6 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; @@ -64,19 +63,6 @@ void whenFunctionCallingOptionsAndExecutionEnabledTrue() { assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options)).isTrue(); } - @Test - void whenFunctionCallingOptionsAndExecutionEnabledFalse() { - FunctionCallingOptions options = FunctionCallingOptions.builder().build(); - options.setProxyToolCalls(true); - assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options)).isFalse(); - } - - @Test - void whenFunctionCallingOptionsAndExecutionEnabledDefault() { - FunctionCallingOptions options = FunctionCallingOptions.builder().build(); - assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options)).isTrue(); - } - @Test void whenMergeRuntimeAndDefaultToolNames() { Set runtimeToolNames = Set.of("toolA"); @@ -111,9 +97,9 @@ void whenMergeEmptyRuntimeAndEmptyDefaultToolNames() { @Test void whenMergeRuntimeAndDefaultToolCallbacks() { - List runtimeToolCallbacks = List.of(new TestToolCallback("toolA")); - List defaultToolCallbacks = List.of(new TestToolCallback("toolB")); - List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + List runtimeToolCallbacks = List.of(new TestToolCallback("toolA")); + List defaultToolCallbacks = List.of(new TestToolCallback("toolB")); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, defaultToolCallbacks); assertThat(mergedToolCallbacks).hasSize(1); assertThat(mergedToolCallbacks.get(0).getName()).isEqualTo("toolA"); @@ -121,9 +107,9 @@ void whenMergeRuntimeAndDefaultToolCallbacks() { @Test void whenMergeRuntimeAndEmptyDefaultToolCallbacks() { - List runtimeToolCallbacks = List.of(new TestToolCallback("toolA")); - List defaultToolCallbacks = List.of(); - List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + List runtimeToolCallbacks = List.of(new TestToolCallback("toolA")); + List defaultToolCallbacks = List.of(); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, defaultToolCallbacks); assertThat(mergedToolCallbacks).hasSize(1); assertThat(mergedToolCallbacks.get(0).getName()).isEqualTo("toolA"); @@ -131,9 +117,9 @@ void whenMergeRuntimeAndEmptyDefaultToolCallbacks() { @Test void whenMergeEmptyRuntimeAndDefaultToolCallbacks() { - List runtimeToolCallbacks = List.of(); - List defaultToolCallbacks = List.of(new TestToolCallback("toolB")); - List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + List runtimeToolCallbacks = List.of(); + List defaultToolCallbacks = List.of(new TestToolCallback("toolB")); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, defaultToolCallbacks); assertThat(mergedToolCallbacks).hasSize(1); assertThat(mergedToolCallbacks.get(0).getName()).isEqualTo("toolB"); @@ -141,9 +127,9 @@ void whenMergeEmptyRuntimeAndDefaultToolCallbacks() { @Test void whenMergeEmptyRuntimeAndEmptyDefaultToolCallbacks() { - List runtimeToolCallbacks = List.of(); - List defaultToolCallbacks = List.of(); - List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + List runtimeToolCallbacks = List.of(); + List defaultToolCallbacks = List.of(); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, defaultToolCallbacks); assertThat(mergedToolCallbacks).hasSize(0); } @@ -191,7 +177,7 @@ void whenMergeEmptyRuntimeAndEmptyDefaultToolContext() { @Test void shouldEnsureUniqueToolNames() { - List toolCallbacks = List.of(new TestToolCallback("toolA"), new TestToolCallback("toolA")); + List toolCallbacks = List.of(new TestToolCallback("toolA"), new TestToolCallback("toolA")); assertThatThrownBy(() -> ToolCallingChatOptions.validateToolCallbacks(toolCallbacks)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("Multiple tools with the same name (toolA)");