Skip to content

Commit fb0d99d

Browse files
tzolovsdeleuze
authored andcommitted
Revamp function callback builder API
Introduces a simplified, type-safe builder pattern for function callbacks to improve developer experience and code reliability. The new hierarchical API separates concerns between direct function invocation and method reflection, while providing better compile-time safety. This change deprecates the older FunctionCallbackWrapper in favor of a more intuitive FunctionCallback.Builder that better handles generic types via ParameterizedTypeReference. It also adds automatic function description generation as a fallback when none is provided, though explicit descriptions are still recommended. The update standardizes function callback handling across all AI model implementations (OpenAI, Ollama, Minimax, etc.) and improves response handling with configurable converters. Core API Enhancements: - New Builder Interface: Replaced FunctionCallbackWrapper.builder() with FunctionCallback.builder(), introducing a hierarchical approach that improves customization and type safety. - Specialized Builders: Introduced FunctionInvokerBuilder for direct Function/BiFunction implementations and MethodInvokerBuilder for reflection-based invocations. - Generic Type Support: Added ParameterizedTypeReference for better handling of generic parameters. - Unified Method Definition: Merged method() and argumentTypes() into a single method() call for simplicity and type safety. - Automatic Descriptions: Implemented auto-generation of function descriptions, with warnings to encourage explicit descriptions. - Configurable Response Converters: Enhanced response handling with support for custom converters, reducing unnecessary JSON conversions. Architecture Improvements: - Established common Builder interface for shared properties - Separated function object handling from constructor - Added method-specific configuration (name, arg types, target) - Added JSON schema generation support for ResolvableType - Moved to standardized schema types across AI providers - Set OPEN_API_SCHEMA as default for Vertex AI Gemini Builder Pattern Standardization: - Standardized builder method ordering across implementations - Moved function() call after description() for consistency - Improved function callback configuration with unified patterns - Enhanced error handling and validation in DefaultFunctionCallbackBuilder Deprecations: - FunctionCallbackWrapper.Builder replaced by DefaultFunctionCallbackBuilder - Removed CustomizedTypeReference in favor of ParameterizedTypeReference - Deprecated older ChatClient API methods for function handling Testing & Documentation: - Updated all AI model implementations (OpenAI, Ollama, Minimax, Moonshot, ZhiPuAI) - Added comprehensive integration tests for static/instance methods - Added integration tests for auto-generated descriptions - Updated documentation to reflect new builder pattern usage - Added Kotlin extension for inputType() support Co-authored-by: Sébastien Deleuze <sebastien.deleuze@broadcom.com>
1 parent 72c84fe commit fb0d99d

File tree

78 files changed

+1694
-908
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

78 files changed

+1694
-908
lines changed

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
import org.springframework.ai.converter.ListOutputConverter;
4949
import org.springframework.ai.converter.MapOutputConverter;
5050
import org.springframework.ai.model.Media;
51-
import org.springframework.ai.model.function.FunctionCallbackWrapper;
51+
import org.springframework.ai.model.function.FunctionCallback;
5252
import org.springframework.beans.factory.annotation.Autowired;
5353
import org.springframework.beans.factory.annotation.Value;
5454
import org.springframework.boot.SpringBootConfiguration;
@@ -256,10 +256,11 @@ void functionCallTest() {
256256

257257
var promptOptions = AnthropicChatOptions.builder()
258258
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName())
259-
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
260-
.withName("getCurrentWeather")
261-
.withDescription(
259+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
260+
.description(
262261
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
262+
.function("getCurrentWeather", new MockWeatherService())
263+
.inputType(MockWeatherService.Request.class)
263264
.build()))
264265
.build();
265266

@@ -283,10 +284,11 @@ void streamFunctionCallTest() {
283284

284285
var promptOptions = AnthropicChatOptions.builder()
285286
.withModel(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
286-
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
287-
.withName("getCurrentWeather")
288-
.withDescription(
287+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
288+
.description(
289289
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
290+
.function("getCurrentWeather", new MockWeatherService())
291+
.inputType(MockWeatherService.Request.class)
290292
.build()))
291293
.build();
292294

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.ai.chat.model.ChatResponse;
4343
import org.springframework.ai.converter.BeanOutputConverter;
4444
import org.springframework.ai.converter.ListOutputConverter;
45+
import org.springframework.ai.model.function.FunctionCallback;
4546
import org.springframework.beans.factory.annotation.Autowired;
4647
import org.springframework.beans.factory.annotation.Value;
4748
import org.springframework.boot.test.context.SpringBootTest;
@@ -210,8 +211,30 @@ void functionCallTest() {
210211

211212
// @formatter:off
212213
String response = ChatClient.create(this.chatModel).prompt()
213-
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
214-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
214+
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
215+
.functions(FunctionCallback.builder()
216+
.function("getCurrentWeather", new MockWeatherService())
217+
.inputType(MockWeatherService.Request.class)
218+
.build())
219+
.call()
220+
.content();
221+
// @formatter:on
222+
223+
logger.info("Response: {}", response);
224+
225+
assertThat(response).contains("30", "10", "15");
226+
}
227+
228+
@Test
229+
void functionCallWithGeneratedDescription() {
230+
231+
// @formatter:off
232+
String response = ChatClient.create(this.chatModel).prompt()
233+
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
234+
.functions(FunctionCallback.builder()
235+
.function("getCurrentWeatherInLocation", new MockWeatherService())
236+
.inputType(MockWeatherService.Request.class)
237+
.build())
215238
.call()
216239
.content();
217240
// @formatter:on
@@ -226,7 +249,11 @@ void defaultFunctionCallTest() {
226249

227250
// @formatter:off
228251
String response = ChatClient.builder(this.chatModel)
229-
.defaultFunction("getCurrentWeather", "Get the weather in location", new MockWeatherService())
252+
.defaultFunctions(FunctionCallback.builder()
253+
.description("Get the weather in location")
254+
.function("getCurrentWeather", new MockWeatherService())
255+
.inputType(MockWeatherService.Request.class)
256+
.build())
230257
.defaultUser(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius."))
231258
.build()
232259
.prompt()
@@ -245,7 +272,11 @@ void streamFunctionCallTest() {
245272
// @formatter:off
246273
Flux<String> response = ChatClient.create(this.chatModel).prompt()
247274
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
248-
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
275+
.functions(FunctionCallback.builder()
276+
.description("Get the weather in location")
277+
.function("getCurrentWeather", new MockWeatherService())
278+
.inputType(MockWeatherService.Request.class)
279+
.build())
249280
.stream()
250281
.content();
251282
// @formatter:on
Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,21 @@
2929
import org.springframework.ai.chat.client.ChatClient;
3030
import org.springframework.ai.chat.model.ChatModel;
3131
import org.springframework.ai.chat.model.ToolContext;
32-
import org.springframework.ai.model.function.MethodFunctionCallback;
32+
import org.springframework.ai.model.function.FunctionCallback;
3333
import org.springframework.beans.factory.annotation.Autowired;
3434
import org.springframework.boot.test.context.SpringBootTest;
3535
import org.springframework.test.context.ActiveProfiles;
36-
import org.springframework.util.ReflectionUtils;
3736

3837
import static org.assertj.core.api.Assertions.assertThat;
3938
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
4039

4140
@SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429")
4241
@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+")
4342
@ActiveProfiles("logging-test")
44-
class AnthropicChatClientMethodFunctionCallbackIT {
43+
class AnthropicChatClientMethodInvokingFunctionCallbackIT {
4544

46-
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatClientMethodFunctionCallbackIT.class);
45+
private static final Logger logger = LoggerFactory
46+
.getLogger(AnthropicChatClientMethodInvokingFunctionCallbackIT.class);
4747

4848
public static Map<String, Object> arguments = new ConcurrentHashMap<>();
4949

@@ -52,16 +52,35 @@ void beforeEach() {
5252
arguments.clear();
5353
}
5454

55+
@Test
56+
void methodGetWeatherGeneratedDescription() {
57+
58+
// @formatter:off
59+
String response = ChatClient.create(this.chatModel).prompt()
60+
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
61+
.functions(FunctionCallback.builder()
62+
.method("getWeatherInLocation", String.class, Unit.class)
63+
.targetClass(TestFunctionClass.class)
64+
.build())
65+
.call()
66+
.content();
67+
// @formatter:on
68+
69+
logger.info("Response: {}", response);
70+
71+
assertThat(response).contains("30", "10", "15");
72+
}
73+
5574
@Test
5675
void methodGetWeatherStatic() {
5776

58-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherStatic", String.class, Unit.class);
5977
// @formatter:off
6078
String response = ChatClient.create(this.chatModel).prompt()
6179
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
62-
.functions(MethodFunctionCallback.builder()
63-
.method(method)
80+
.functions(FunctionCallback.builder()
6481
.description("Get the weather in location")
82+
.method("getWeatherStatic", String.class, Unit.class)
83+
.targetClass(TestFunctionClass.class)
6584
.build())
6685
.call()
6786
.content();
@@ -77,15 +96,13 @@ void methodTurnLightNoResponse() {
7796

7897
TestFunctionClass targetObject = new TestFunctionClass();
7998

80-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLight", String.class, boolean.class);
81-
8299
// @formatter:off
83100
String response = ChatClient.create(this.chatModel).prompt()
84101
.user("Turn light on in the living room.")
85-
.functions(MethodFunctionCallback.builder()
86-
.functionObject(targetObject)
87-
.method(method)
88-
.description("Can turn lights on or off by room name")
102+
.functions(FunctionCallback.builder()
103+
.description("Turn light on in the living room.")
104+
.method("turnLight", String.class, boolean.class)
105+
.targetObject(targetObject)
89106
.build())
90107
.call()
91108
.content();
@@ -102,16 +119,13 @@ void methodGetWeatherNonStatic() {
102119

103120
TestFunctionClass targetObject = new TestFunctionClass();
104121

105-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
106-
Unit.class);
107-
108122
// @formatter:off
109123
String response = ChatClient.create(this.chatModel).prompt()
110124
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
111-
.functions(MethodFunctionCallback.builder()
112-
.functionObject(targetObject)
113-
.method(method)
125+
.functions(FunctionCallback.builder()
114126
.description("Get the weather in location")
127+
.method("getWeatherNonStatic",String.class, Unit.class)
128+
.targetObject(targetObject)
115129
.build())
116130
.call()
117131
.content();
@@ -127,17 +141,14 @@ void methodGetWeatherToolContext() {
127141

128142
TestFunctionClass targetObject = new TestFunctionClass();
129143

130-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherWithContext", String.class,
131-
Unit.class, ToolContext.class);
132-
133144
// @formatter:off
134145
String response = ChatClient.create(this.chatModel).prompt()
135146
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
136-
.functions(MethodFunctionCallback.builder()
137-
.functionObject(targetObject)
138-
.method(method)
147+
.functions(FunctionCallback.builder()
139148
.description("Get the weather in location")
140-
.build())
149+
.method("getWeatherWithContext", String.class, Unit.class, ToolContext.class)
150+
.targetObject(targetObject)
151+
.build())
141152
.toolContext(Map.of("tool", "value"))
142153
.call()
143154
.content();
@@ -154,17 +165,14 @@ void methodGetWeatherToolContextButNonContextMethod() {
154165

155166
TestFunctionClass targetObject = new TestFunctionClass();
156167

157-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
158-
Unit.class);
159-
160168
// @formatter:off
161169
assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt()
162170
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
163-
.functions(MethodFunctionCallback.builder()
164-
.functionObject(targetObject)
165-
.method(method)
166-
.description("Get the weather in location")
167-
.build())
171+
.functions(FunctionCallback.builder()
172+
.description("Get the weather in location")
173+
.method("getWeatherNonStatic", String.class, Unit.class)
174+
.targetObject(targetObject)
175+
.build())
168176
.toolContext(Map.of("tool", "value"))
169177
.call()
170178
.content())
@@ -178,15 +186,13 @@ void methodNoParameters() {
178186

179187
TestFunctionClass targetObject = new TestFunctionClass();
180188

181-
var method = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLivingRoomLightOn");
182-
183189
// @formatter:off
184190
String response = ChatClient.create(this.chatModel).prompt()
185191
.user("Turn light on in the living room.")
186-
.functions(MethodFunctionCallback.builder()
187-
.functionObject(targetObject)
188-
.method(method)
192+
.functions(FunctionCallback.builder()
189193
.description("Can turn lights on in the Living Room")
194+
.method("turnLivingRoomLightOn")
195+
.targetObject(targetObject)
190196
.build())
191197
.call()
192198
.content();
@@ -215,6 +221,10 @@ public static void argumentLessReturnVoid() {
215221
arguments.put("method called", "argumentLessReturnVoid");
216222
}
217223

224+
public static String getWeatherInLocation(String city, Unit unit) {
225+
return getWeatherStatic(city, unit);
226+
}
227+
218228
public static String getWeatherStatic(String city, Unit unit) {
219229

220230
logger.info("City: " + city + " Unit: " + unit);

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import org.springframework.ai.chat.model.ChatResponse;
4040
import org.springframework.ai.chat.model.Generation;
4141
import org.springframework.ai.chat.prompt.Prompt;
42-
import org.springframework.ai.model.function.FunctionCallbackWrapper;
42+
import org.springframework.ai.model.function.FunctionCallback;
4343
import org.springframework.beans.factory.annotation.Autowired;
4444
import org.springframework.boot.SpringBootConfiguration;
4545
import org.springframework.boot.test.context.SpringBootTest;
@@ -70,10 +70,10 @@ void functionCallTest() {
7070

7171
var promptOptions = AzureOpenAiChatOptions.builder()
7272
.withDeploymentName(this.selectedModel)
73-
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
74-
.withName("getCurrentWeather")
75-
.withDescription("Get the current weather in a given location")
76-
.withResponseConverter(response -> "" + response.temp() + response.unit())
73+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
74+
.description("Get the current weather in a given location")
75+
.function("getCurrentWeather", new MockWeatherService())
76+
.inputType(MockWeatherService.Request.class)
7777
.build()))
7878
.build();
7979

@@ -94,10 +94,10 @@ void functionCallSequentialTest() {
9494

9595
var promptOptions = AzureOpenAiChatOptions.builder()
9696
.withDeploymentName(this.selectedModel)
97-
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
98-
.withName("getCurrentWeather")
99-
.withDescription("Get the current weather in a given location")
100-
.withResponseConverter(response -> "" + response.temp() + response.unit())
97+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
98+
.description("Get the current weather in a given location")
99+
.function("getCurrentWeather", new MockWeatherService())
100+
.inputType(MockWeatherService.Request.class)
101101
.build()))
102102
.build();
103103

@@ -116,10 +116,10 @@ void streamFunctionCallTest() {
116116

117117
var promptOptions = AzureOpenAiChatOptions.builder()
118118
.withDeploymentName(this.selectedModel)
119-
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
120-
.withName("getCurrentWeather")
121-
.withDescription("Get the current weather in a given location")
122-
.withResponseConverter(response -> "" + response.temp() + response.unit())
119+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
120+
.description("Get the current weather in a given location")
121+
.function("getCurrentWeather", new MockWeatherService())
122+
.inputType(MockWeatherService.Request.class)
123123
.build()))
124124
.build();
125125

@@ -153,10 +153,10 @@ void functionCallSequentialAndStreamTest() {
153153

154154
var promptOptions = AzureOpenAiChatOptions.builder()
155155
.withDeploymentName(this.selectedModel)
156-
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
157-
.withName("getCurrentWeather")
158-
.withDescription("Get the current weather in a given location")
159-
.withResponseConverter(response -> "" + response.temp() + response.unit())
156+
.withFunctionCallbacks(List.of(FunctionCallback.builder()
157+
.description("Get the current weather in a given location")
158+
.function("getCurrentWeather", new MockWeatherService())
159+
.inputType(MockWeatherService.Request.class)
160160
.build()))
161161
.build();
162162

0 commit comments

Comments
 (0)