Skip to content

Commit b2a4f01

Browse files
Kirbstompertzolov
andcommitted
Add Spring Annotation for registerding Function Calling tools.
- Built on top of the existing ToolFunctionCallback utilites. - Organized as part of auto-configuration project under the /common/function package to be reusable for different model implementations. - Add OpenAI support for the SpringAiFunction annotation. - Update the OpenAI function calling documentation. - Add ITs Co-Authored-By: Christian Tzolov <ctzolov@vmware.com>
1 parent 57c66b9 commit b2a4f01

File tree

6 files changed

+285
-8
lines changed

6 files changed

+285
-8
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/messages/FunctionMessage.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
public class FunctionMessage extends AbstractMessage {
2222

2323
public FunctionMessage(String content) {
24-
super(MessageType.FUNCTION, content);
24+
super(MessageType.SYSTEM, content);
2525
}
2626

2727
public FunctionMessage(String content, Map<String, Object> properties) {
28-
super(MessageType.FUNCTION, content, properties);
28+
super(MessageType.SYSTEM, content, properties);
2929
}
3030

3131
@Override

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/functions/openai-chat-functions.adoc

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,16 @@ The Spring AI auto-generates the JSON Scheme for the `MockWeatherService.Request
6464

6565
If you enable the link:../openai-chat.html#_auto_configuration[OpenAiChatClient Auto-Configuration], the easiest way to register a function is to created it as a bean in the Spring context:
6666

67-
[source,java,linenums]
67+
[source,java]
6868
----
6969
@Configuration
7070
static class Config {
7171
@Bean
7272
public WeatherFunctionCallback weatherFunctionInfo() {
7373
return new WeatherFunctionCallback(
74-
"CurrentWeather", // (1) name
75-
"Get the weather in location", // (2) description
76-
MockWeatherService.Request.class); // (3) signature
74+
"CurrentWeather", // (1) function name
75+
"Get the weather in location", // (2) function description
76+
MockWeatherService.Request.class); // (3) function input signature
7777
}
7878
...
7979
}
@@ -106,6 +106,35 @@ Here is the current weather for the requested cities:
106106

107107
The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/ToolCallWithBeanFunctionRegistrationIT.java[ToolCallWithBeanFunctionRegistrationIT.java] integration test provides a complete example of how to register a function with the `OpenAiChatClient` using the auto-configuration.
108108

109+
==== @SpringAiFunction
110+
111+
You can use the `SpringAiFunction` annotation cam be used to register a `java.util.Function<I,O>` as a `ToolFunctionCallback` bean:
112+
113+
[source,java]
114+
----
115+
@Configuration
116+
static class Config {
117+
118+
@SpringAiFunction(
119+
name = "CurrentWeather", // (1)
120+
description = "Get the weather in location", // (2)
121+
classType = MockWeatherService.Request.class) // (3)
122+
public Function<Request, Response> weatherFunction() {
123+
MockWeatherService weatherService = new MockWeatherService();
124+
return (weatherService::apply);
125+
}
126+
127+
...
128+
}
129+
----
130+
131+
The `@SpringAiFunction` annotation defines the function name (1), description (2), and input signature (3) and registers the function as a bean in the Spring context.
132+
133+
NOTE: The `SpringAiFunction` annotation supported only if the auto-configuration is enabled.
134+
135+
NOTE: The Function<I, O> implementation is responsible to convert the response into a text as expected by the model.
136+
By default, the `AbstractToolFunctionCallback` provides a default converter that returns the `toString()` of the response object.
137+
109138
=== Register/Call Functions with Prompt Options
110139

111140
In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.common.function;
18+
19+
import java.lang.annotation.ElementType;
20+
import java.lang.annotation.Retention;
21+
import java.lang.annotation.RetentionPolicy;
22+
import java.lang.annotation.Target;
23+
24+
import org.springframework.context.annotation.Bean;
25+
26+
/**
27+
* An annotation used to define functions for use in
28+
*
29+
* @author Christopher Smith
30+
*/
31+
@Bean
32+
@Target(ElementType.METHOD)
33+
@Retention(RetentionPolicy.RUNTIME)
34+
public @interface SpringAiFunction {
35+
36+
String name();
37+
38+
String description();
39+
40+
Class<?> classType();
41+
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.autoconfigure.common.function;
17+
18+
import java.util.ArrayList;
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.function.Function;
22+
23+
import org.springframework.ai.model.AbstractToolFunctionCallback;
24+
import org.springframework.ai.model.ToolFunctionCallback;
25+
import org.springframework.beans.BeansException;
26+
import org.springframework.context.ApplicationContext;
27+
import org.springframework.context.ApplicationContextAware;
28+
import org.springframework.context.support.GenericApplicationContext;
29+
import org.springframework.lang.NonNull;
30+
import org.springframework.util.Assert;
31+
import org.springframework.util.CollectionUtils;
32+
import org.springframework.util.ReflectionUtils;
33+
34+
/**
35+
* Manages the chat functions that are annotated with {@link SpringAiFunction}.
36+
*
37+
* @author Christopher Smith
38+
* @author Christian Tzolov
39+
*/
40+
public class SpringAiFunctionAnnotationManager implements ApplicationContextAware {
41+
42+
private GenericApplicationContext applicationContext;
43+
44+
@Override
45+
public void setApplicationContext(@NonNull ApplicationContext applicationContext) throws BeansException {
46+
this.applicationContext = (GenericApplicationContext) applicationContext;
47+
}
48+
49+
/**
50+
* @return a list of all the {@link java.util.Function}s annotated with
51+
* {@link SpringAiFunction}.
52+
*/
53+
public List<ToolFunctionCallback> getAnnotatedToolFunctionCallbacks() {
54+
Map<String, Object> beans = this.applicationContext.getBeansWithAnnotation(SpringAiFunction.class);
55+
56+
List<ToolFunctionCallback> toolFunctionCallbacks = new ArrayList<>();
57+
58+
if (!CollectionUtils.isEmpty(beans)) {
59+
60+
beans.forEach((k, v) -> {
61+
if (v instanceof Function<?, ?> function) {
62+
SpringAiFunction functionAnnotation = applicationContext.findAnnotationOnBean(k,
63+
SpringAiFunction.class);
64+
65+
toolFunctionCallbacks.add(new SpringAiFunctionToolFunctionCallback(functionAnnotation.name(),
66+
functionAnnotation.description(), functionAnnotation.classType(), function));
67+
}
68+
else {
69+
ReflectionUtils.handleReflectionException(new IllegalArgumentException(
70+
"Bean annotated with @SpringAiFunction must be of type Function"));
71+
}
72+
});
73+
74+
}
75+
76+
return toolFunctionCallbacks;
77+
}
78+
79+
/**
80+
* Note that the underlying function is responsible for converting the output into
81+
* format that can be consumed by the Model. The default implementation converts the
82+
* output into String before sending it to the Model. Provide a custom Function<O,
83+
* String> responseConverter implementation to override this.
84+
*
85+
*/
86+
public static class SpringAiFunctionToolFunctionCallback<I, O> extends AbstractToolFunctionCallback<I, O> {
87+
88+
private Function<I, O> function;
89+
90+
protected SpringAiFunctionToolFunctionCallback(String name, String description, Class<I> inputType,
91+
Function<I, O> function) {
92+
super(name, description, inputType);
93+
Assert.notNull(function, "Function must not be null");
94+
this.function = function;
95+
}
96+
97+
protected SpringAiFunctionToolFunctionCallback(String name, String description, Class<I> inputType,
98+
Function<O, String> responseConverter, Function<I, O> function) {
99+
super(name, description, inputType, responseConverter);
100+
Assert.notNull(function, "Function must not be null");
101+
this.function = function;
102+
}
103+
104+
@Override
105+
public O apply(I input) {
106+
return this.function.apply(input);
107+
}
108+
109+
}
110+
111+
}

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.List;
2020

2121
import org.springframework.ai.autoconfigure.NativeHints;
22+
import org.springframework.ai.autoconfigure.common.function.SpringAiFunctionAnnotationManager;
2223
import org.springframework.ai.embedding.EmbeddingClient;
2324
import org.springframework.ai.model.ToolFunctionCallback;
2425
import org.springframework.ai.openai.OpenAiChatClient;
@@ -31,14 +32,15 @@
3132
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
3233
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
3334
import org.springframework.boot.context.properties.EnableConfigurationProperties;
35+
import org.springframework.context.ApplicationContext;
3436
import org.springframework.context.annotation.Bean;
3537
import org.springframework.context.annotation.ImportRuntimeHints;
3638
import org.springframework.util.Assert;
3739
import org.springframework.util.CollectionUtils;
3840
import org.springframework.util.StringUtils;
3941
import org.springframework.web.client.RestClient;
4042

41-
@AutoConfiguration(after = RestClientAutoConfiguration.class)
43+
@AutoConfiguration(after = { RestClientAutoConfiguration.class })
4244
@ConditionalOnClass(OpenAiApi.class)
4345
@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class,
4446
OpenAiEmbeddingProperties.class, OpenAiImageProperties.class })
@@ -56,7 +58,7 @@ public class OpenAiAutoConfiguration {
5658
@ConditionalOnMissingBean
5759
public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties,
5860
OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder,
59-
List<ToolFunctionCallback> toolFunctionCallbacks) {
61+
List<ToolFunctionCallback> toolFunctionCallbacks, SpringAiFunctionAnnotationManager functionManager) {
6062

6163
String apiKey = StringUtils.hasText(chatProperties.getApiKey()) ? chatProperties.getApiKey()
6264
: commonProperties.getApiKey();
@@ -73,6 +75,11 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper
7375
chatProperties.getOptions().getToolCallbacks().addAll(toolFunctionCallbacks);
7476
}
7577

78+
var annotatedFunctionsList = functionManager.getAnnotatedToolFunctionCallbacks();
79+
if (!CollectionUtils.isEmpty(annotatedFunctionsList)) {
80+
chatProperties.getOptions().getToolCallbacks().addAll(annotatedFunctionsList);
81+
}
82+
7683
return new OpenAiChatClient(openAiApi, chatProperties.getOptions());
7784
}
7885

@@ -113,4 +120,12 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp
113120
return new OpenAiImageClient(openAiImageApi).withDefaultOptions(imageProperties.getOptions());
114121
}
115122

123+
@Bean
124+
@ConditionalOnMissingBean
125+
public SpringAiFunctionAnnotationManager springAiFunctionManager(ApplicationContext context) {
126+
SpringAiFunctionAnnotationManager manager = new SpringAiFunctionAnnotationManager();
127+
manager.setApplicationContext(context);
128+
return manager;
129+
}
130+
116131
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.openai.tool;
18+
19+
import java.util.List;
20+
import java.util.function.Function;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
24+
import org.slf4j.Logger;
25+
import org.slf4j.LoggerFactory;
26+
27+
import org.springframework.ai.autoconfigure.common.function.SpringAiFunction;
28+
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
29+
import org.springframework.ai.chat.ChatResponse;
30+
import org.springframework.ai.chat.messages.UserMessage;
31+
import org.springframework.ai.chat.prompt.Prompt;
32+
import org.springframework.ai.openai.OpenAiChatClient;
33+
import org.springframework.ai.openai.OpenAiChatOptions;
34+
import org.springframework.boot.autoconfigure.AutoConfigurations;
35+
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
36+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
37+
import org.springframework.context.annotation.Configuration;
38+
39+
import static org.assertj.core.api.Assertions.assertThat;
40+
41+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
42+
class ToolCallWithSpringAIFunctionAnnotationIT {
43+
44+
private final Logger logger = LoggerFactory.getLogger(ToolCallWithBeanFunctionRegistrationIT.class);
45+
46+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
47+
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
48+
.withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
49+
.withUserConfiguration(Config.class);
50+
51+
@Test
52+
void functionCallTest() {
53+
contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=gpt-4-1106-preview").run(context -> {
54+
55+
OpenAiChatClient chatClient = context.getBean(OpenAiChatClient.class);
56+
57+
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
58+
59+
ChatResponse response = chatClient.call(new Prompt(List.of(userMessage),
60+
OpenAiChatOptions.builder().withEnabledFunction("WeatherInfo").build()));
61+
62+
logger.info("Response: {}", response);
63+
64+
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
65+
66+
});
67+
}
68+
69+
@Configuration
70+
static class Config {
71+
@SpringAiFunction(name = "WeatherInfo", description = "Get the weather in location",
72+
classType = MockWeatherService.Request.class)
73+
public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunction() {
74+
MockWeatherService weatherService = new MockWeatherService();
75+
return (weatherService::apply);
76+
}
77+
78+
}
79+
80+
}

0 commit comments

Comments
 (0)