Skip to content

Commit a8e305d

Browse files
tzolovilayaperumalg
authored andcommitted
refactor(bedrock): Migrate from function calling to tool calling
- Replace function calling with tool calling in BedrockProxyChatModel - Deprecate function calling related code and APIs - Add new tool calling manager and options - Update builder pattern to remove "with" prefix from methods - Update tests and documentation for tool calling Part of the #2207 epic Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 5ebe485 commit a8e305d

File tree

12 files changed

+447
-199
lines changed

12 files changed

+447
-199
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 307 additions & 82 deletions
Large diffs are not rendered by default.

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -331,6 +331,10 @@ public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions de
331331
attributes.remove("toolContext");
332332
attributes.remove("functionCallbacks");
333333

334+
attributes.remove("toolCallbacks");
335+
attributes.remove("toolNames");
336+
attributes.remove("internalToolExecutionEnabled");
337+
334338
attributes.remove("temperature");
335339
attributes.remove("topK");
336340
attributes.remove("stopSequences");

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,8 @@ void multiModalityEmbeddedImage(String modelName) throws IOException {
380380

381381
@ParameterizedTest(name = "{0} : {displayName} ")
382382
@ValueSource(strings = { "anthropic.claude-3-5-sonnet-20240620-v1:0" })
383-
void multiModalityImageUrl(String modelName) throws IOException {
383+
@Deprecated
384+
void multiModalityImageUrl2(String modelName) throws IOException {
384385

385386
// TODO: add url method that wrapps the checked exception.
386387
URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png");
@@ -398,6 +399,26 @@ void multiModalityImageUrl(String modelName) throws IOException {
398399
assertThat(response).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand");
399400
}
400401

402+
@ParameterizedTest(name = "{0} : {displayName} ")
403+
@ValueSource(strings = { "anthropic.claude-3-5-sonnet-20240620-v1:0" })
404+
void multiModalityImageUrl(String modelName) throws IOException {
405+
406+
// TODO: add url method that wrapps the checked exception.
407+
URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png");
408+
409+
// @formatter:off
410+
String response = ChatClient.create(this.chatModel).prompt()
411+
// TODO consider adding model(...) method to ChatClient as a shortcut to
412+
.options(ToolCallingChatOptions.builder().model(modelName).build())
413+
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
414+
.call()
415+
.content();
416+
// @formatter:on
417+
418+
logger.info(response);
419+
assertThat(response).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand");
420+
}
421+
401422
@Test
402423
void streamingMultiModalityImageUrl() throws IOException {
403424

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -38,10 +38,10 @@ public BedrockProxyChatModel bedrockConverseChatModel() {
3838
String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0";
3939

4040
return BedrockProxyChatModel.builder()
41-
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
42-
.withRegion(Region.US_EAST_1)
43-
.withTimeout(Duration.ofSeconds(120))
44-
// .withRegion(Region.US_EAST_1)
41+
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
42+
.region(Region.US_EAST_1)
43+
// .region(Region.US_EAST_1)
44+
.timeout(Duration.ofSeconds(120))
4545
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
4646
.build();
4747
}

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import io.micrometer.observation.ObservationRegistry;
2222
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Disabled;
2324
import org.junit.jupiter.api.Test;
2425
import org.junit.jupiter.api.extension.ExtendWith;
2526
import org.mockito.Mock;

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
import org.springframework.ai.model.Media;
4949
import org.springframework.ai.model.function.FunctionCallback;
5050
import org.springframework.ai.model.function.FunctionCallingOptions;
51+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
52+
import org.springframework.ai.tool.function.FunctionToolCallback;
5153
import org.springframework.beans.factory.annotation.Autowired;
5254
import org.springframework.beans.factory.annotation.Value;
5355
import org.springframework.boot.test.context.SpringBootTest;
@@ -244,8 +246,9 @@ void multiModalityTest() throws IOException {
244246
"fruit stand");
245247
}
246248

249+
@Deprecated
247250
@Test
248-
void functionCallTest() {
251+
void functionCallTestDeprecated() {
249252

250253
UserMessage userMessage = new UserMessage(
251254
"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");
@@ -269,6 +272,29 @@ void functionCallTest() {
269272
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
270273
}
271274

275+
@Test
276+
void functionCallTest() {
277+
278+
UserMessage userMessage = new UserMessage(
279+
"What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius.");
280+
281+
List<Message> messages = new ArrayList<>(List.of(userMessage));
282+
283+
var promptOptions = ToolCallingChatOptions.builder()
284+
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
285+
.description("Get the weather in location. Return in 36°C format")
286+
.inputType(MockWeatherService.Request.class)
287+
.build()))
288+
.build();
289+
290+
ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
291+
292+
logger.info("Response: {}", response);
293+
294+
Generation generation = response.getResult();
295+
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
296+
}
297+
272298
@Test
273299
void streamFunctionCallTest() {
274300

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.time.Duration;
2121
import java.util.Set;
2222

23+
import org.junit.jupiter.api.Disabled;
2324
import org.junit.jupiter.api.Test;
2425
import org.slf4j.Logger;
2526
import org.slf4j.LoggerFactory;
@@ -32,7 +33,6 @@
3233
import org.springframework.ai.chat.model.ChatModel;
3334
import org.springframework.ai.model.Media;
3435
import org.springframework.ai.model.function.FunctionCallingOptions;
35-
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3636
import org.springframework.ai.tool.function.FunctionToolCallback;
3737
import org.springframework.beans.factory.annotation.Autowired;
3838
import org.springframework.boot.SpringBootConfiguration;
@@ -47,6 +47,7 @@
4747
/**
4848
* @author Christian Tzolov
4949
*/
50+
@Disabled
5051
@SpringBootTest(classes = BedrockNovaChatClientIT.Config.class)
5152
@RequiresAwsCredentials
5253
public class BedrockNovaChatClientIT {
@@ -181,9 +182,9 @@ public BedrockProxyChatModel bedrockConverseChatModel() {
181182
String modelId = "amazon.nova-pro-v1:0";
182183

183184
return BedrockProxyChatModel.builder()
184-
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
185-
.withRegion(Region.US_EAST_1)
186-
.withTimeout(Duration.ofSeconds(120))
185+
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
186+
.region(Region.US_EAST_1)
187+
.timeout(Duration.ofSeconds(120))
187188
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
188189
.build();
189190
}

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain2.java

Lines changed: 0 additions & 76 deletions
This file was deleted.

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters.Voice;
6363
import org.springframework.ai.openai.api.tool.MockWeatherService;
6464
import org.springframework.ai.openai.testutils.AbstractIT;
65+
import org.springframework.ai.tool.function.FunctionToolCallback;
6566
import org.springframework.beans.factory.annotation.Value;
6667
import org.springframework.boot.test.context.SpringBootTest;
6768
import org.springframework.core.convert.support.DefaultConversionService;
@@ -332,7 +333,8 @@ void beanStreamOutputConverterRecords() {
332333
}
333334

334335
@Test
335-
void functionCallTest() {
336+
@Deprecated
337+
void functionCallTestDeprecated() {
336338

337339
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
338340

@@ -356,6 +358,28 @@ void functionCallTest() {
356358
assertThat(response.getResult().getOutput().getText()).containsAnyOf("15.0", "15");
357359
}
358360

361+
@Test
362+
void functionCallTest() {
363+
364+
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
365+
366+
List<Message> messages = new ArrayList<>(List.of(userMessage));
367+
368+
var promptOptions = OpenAiChatOptions.builder()
369+
.model(OpenAiApi.ChatModel.GPT_4_O.getValue())
370+
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
371+
.description("Get the weather in location")
372+
.inputType(MockWeatherService.Request.class)
373+
.build()))
374+
.build();
375+
376+
ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
377+
378+
logger.info("Response: {}", response);
379+
380+
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
381+
}
382+
359383
@Test
360384
void streamFunctionCallTest() {
361385

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,19 @@ The prefix `spring.ai.bedrock.converse.chat` is the property prefix that configu
9090

9191
== Runtime Options [[chat-options]]
9292

93-
Use the portable `ChatOptions` or `FunctionCallingOptions` portable builders to create model configurations, such as temperature, maxToken, topP, etc.
93+
Use the portable `ChatOptions` or `ToolCallingChatOptions` portable builders to create model configurations, such as temperature, maxToken, topP, etc.
9494

9595
On start-up, the default options can be configured with the `BedrockConverseProxyChatModel(api, options)` constructor or the `spring.ai.bedrock.converse.chat.options.*` properties.
9696

9797
At run-time you can override the default options by adding new, request specific, options to the `Prompt` call:
9898

9999
[source,java]
100100
----
101-
var options = FunctionCallingOptions.builder()
102-
.withModel("anthropic.claude-3-5-sonnet-20240620-v1:0")
103-
.withTemperature(0.6)
104-
.withMaxTokens(300)
105-
.withFunctionCallbacks(List.of(FunctionCallback.builder()
106-
.function("getCurrentWeather", new WeatherService())
101+
var options = ToolCallingChatOptions.builder()
102+
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
103+
.temperature(0.6)
104+
.maxTokens(300)
105+
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new WeatherService())
107106
.description("Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
108107
.inputType(WeatherService.Request.class)
109108
.build()))
@@ -118,7 +117,28 @@ String response = ChatClient.create(this.chatModel)
118117

119118
== Tool/Function Calling
120119

121-
The Bedrock Converse API supports function calling capabilities, allowing models to use tools during conversations. Here's an example of how to define and use functions:
120+
The Bedrock Converse API supports tool calling capabilities, allowing models to use tools during conversations.
121+
Here's an example of how to define and use @Tool based tools:
122+
123+
[source,java]
124+
----
125+
126+
public class WeatherService {
127+
128+
@Tool(description = "Get the weather in location")
129+
public String weatherByLocation(@ToolParam(description= "City or state name") String location) {
130+
...
131+
}
132+
}
133+
134+
String response = ChatClient.create(this.chatModel)
135+
.prompt("What's the weather like in Boston?")
136+
.tools(new WeatherService())
137+
.call()
138+
.content();
139+
----
140+
141+
You can use the java.util.function beans as tools as well:
122142

123143
[source,java]
124144
----
@@ -130,12 +150,14 @@ public Function<Request, Response> weatherFunction() {
130150
131151
String response = ChatClient.create(this.chatModel)
132152
.prompt("What's the weather like in Boston?")
133-
.function("weatherFunction")
153+
.tools("weatherFunction")
134154
.inputType(Request.class)
135155
.call()
136156
.content();
137157
----
138158

159+
Find more in xref:api/tools.adoc[Tools] documentation.
160+
139161
== Multimodal
140162

141163
Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, video, pdf, doc, html, md and more data formats.

0 commit comments

Comments
 (0)