Skip to content

Commit 432954d

Browse files
tzolovmarkpollack
authored andcommitted
Add Supplier and Consumer function callback support in function calling
Add support for no-argument Supplier and single-argument Consumer function callbacks in the Spring AI core module. This enhancement allows: - Registration of Supplier<O> callbacks with no input (Void) type - Registration of Consumer<I> callbacks with no output (Void) type - Support for Kotlin Function0 (equivalent to Java Supplier) - Handle empty properties for Void input types in schema generation - Enhance FunctionCallback builder to support Supplier/Consumer patterns Additional changes: - Add test coverage for both Supplier and Consumer callbacks in various scenarios - Enhance TypeResolverHelper to support Consumer input type resolution - Support lambda-style function declarations for improved ergonomics - Add test cases for void input/output handling in OpenAI chat model - Include examples of function calls without return values - Add support for parameterless functions through Supplier interface Add comprehensive documentation for the FunctionCallback API: - Overview of the interface and its key methods - Builder pattern usage with function and method invocation approaches - Examples for different function types (Function, BiFunction, Supplier, Consumer) - Best practices and common pitfalls - Schema generation and customization options Resolves #1718 , #1277 , #1118, #860
1 parent f9a9c02 commit 432954d

File tree

14 files changed

+813
-196
lines changed

14 files changed

+813
-196
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechModelIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ void shouldStreamNonEmptyResponsesForValidSpeechPrompts() {
106106
List<SpeechResponse> responses = responseFlux.collectList().block();
107107
assertThat(responses).isNotNull();
108108
responses.forEach(response -> {
109-
System.out.println("Audio data chunk size: " + response.getResult().getOutput().length);
109+
// System.out.println("Audio data chunk size: " +
110+
// response.getResult().getOutput().length);
110111
assertThat(response.getResult().getOutput()).isNotEmpty();
111112
});
112113
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.ArrayList;
2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.concurrent.ConcurrentHashMap;
2223
import java.util.function.BiFunction;
2324
import java.util.stream.Collectors;
2425

@@ -28,6 +29,7 @@
2829
import org.slf4j.LoggerFactory;
2930
import reactor.core.publisher.Flux;
3031

32+
import org.springframework.ai.chat.client.ChatClient;
3133
import org.springframework.ai.chat.messages.AssistantMessage;
3234
import org.springframework.ai.chat.messages.Message;
3335
import org.springframework.ai.chat.messages.UserMessage;
@@ -59,6 +61,25 @@ class OpenAiChatModelFunctionCallingIT {
5961
@Autowired
6062
ChatModel chatModel;
6163

64+
@Test
65+
void functionCallSupplier() {
66+
67+
Map<String, Object> state = new ConcurrentHashMap<>();
68+
69+
// @formatter:off
70+
String response = ChatClient.create(this.chatModel).prompt()
71+
.user("Turn the light on in the living room")
72+
.functions(FunctionCallback.builder()
73+
.function("turnsLightOnInTheLivingRoom", () -> state.put("Light", "ON"))
74+
.build())
75+
.call()
76+
.content();
77+
// @formatter:on
78+
79+
logger.info("Response: {}", response);
80+
assertThat(state).containsEntry("Light", "ON");
81+
}
82+
6283
@Test
6384
void functionCallTest() {
6485
functionCallTest(OpenAiChatOptions.builder()

spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ private static String toGetName(String name) {
340340
* @return the generated JSON Schema as a String.
341341
* @deprecated use {@link #getJsonSchema(Type, boolean)} instead.
342342
*/
343-
@Deprecated
343+
@Deprecated(since = "1.0 M4")
344344
public static String getJsonSchema(Class<?> clazz, boolean toUpperCaseTypeValues) {
345345

346346
if (SCHEMA_GENERATOR_CACHE.get() == null) {
@@ -395,6 +395,11 @@ public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues
395395
}
396396

397397
ObjectNode node = SCHEMA_GENERATOR_CACHE.get().generateSchema(inputType);
398+
399+
if ((inputType == Void.class) && !node.has("properties")) {
400+
node.putObject("properties");
401+
}
402+
398403
if (toUpperCaseTypeValues) { // Required for OpenAPI 3.0 (at least Vertex AI
399404
// version of it).
400405
toUpperCaseTypeValues(node);

spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,27 @@
11
/*
2-
* Copyright 2024 - 2024 the original author or authors.
3-
*
4-
* Licensed under the Apache License, Version 2.0 (the "License");
5-
* you may not use this file except in compliance with the License.
6-
* You may obtain a copy of the License at
7-
*
8-
* https://www.apache.org/licenses/LICENSE-2.0
9-
*
10-
* Unless required by applicable law or agreed to in writing, software
11-
* distributed under the License is distributed on an "AS IS" BASIS,
12-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
* See the License for the specific language governing permissions and
14-
* limitations under the License.
15-
*/
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
1617
package org.springframework.ai.model.function;
1718

1819
import java.lang.reflect.Type;
1920
import java.util.Arrays;
2021
import java.util.function.BiFunction;
22+
import java.util.function.Consumer;
2123
import java.util.function.Function;
24+
import java.util.function.Supplier;
2225

2326
import com.fasterxml.jackson.core.JsonProcessingException;
2427
import com.fasterxml.jackson.databind.DeserializationFeature;
@@ -43,7 +46,7 @@
4346

4447
/**
4548
* Default implementation of the {@link FunctionCallback.Builder}.
46-
*
49+
*
4750
* @author Christian Tzolov
4851
* @since 1.0.0
4952
*/
@@ -137,6 +140,20 @@ public <I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, Too
137140
return new DefaultFunctionInvokingSpec<>(name, biFunction);
138141
}
139142

143+
@Override
144+
public <O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier) {
145+
Function<Void, O> function = (input) -> supplier.get();
146+
return new DefaultFunctionInvokingSpec<>(name, function).inputType(Void.class);
147+
}
148+
149+
public <I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer) {
150+
Function<I, Void> function = (I input) -> {
151+
consumer.accept(input);
152+
return null;
153+
};
154+
return new DefaultFunctionInvokingSpec<>(name, function);
155+
}
156+
140157
@Override
141158
public MethodInvokingSpec method(String methodName, Class<?>... argumentTypes) {
142159
return new DefaultMethodInvokingSpec(methodName, argumentTypes);

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.springframework.ai.model.function;
1818

1919
import java.util.function.BiFunction;
20+
import java.util.function.Consumer;
2021
import java.util.function.Function;
22+
import java.util.function.Supplier;
2123

2224
import com.fasterxml.jackson.databind.ObjectMapper;
2325

@@ -141,6 +143,16 @@ interface Builder {
141143
*/
142144
<I, O> FunctionInvokingSpec<I, O> function(String name, BiFunction<I, ToolContext, O> biFunction);
143145

146+
/**
147+
* Builds a {@link Supplier} invoking {@link FunctionCallback} instance.
148+
*/
149+
<O> FunctionInvokingSpec<Void, O> function(String name, Supplier<O> supplier);
150+
151+
/**
152+
* Builds a {@link Consumer} invoking {@link FunctionCallback} instance.
153+
*/
154+
<I> FunctionInvokingSpec<I, Void> function(String name, Consumer<I> consumer);
155+
144156
/**
145157
* Builds a Method invoking {@link FunctionCallback} instance.
146158
*/
@@ -189,14 +201,14 @@ interface MethodInvokingSpec {
189201
MethodInvokingSpec name(String name);
190202

191203
/**
192-
* For non static objects the target object is used to invoke the method.
204+
* For non-static objects the target object is used to invoke the method.
193205
* @param methodObject target object where the method is defined.
194206
*/
195207
MethodInvokingSpec targetObject(Object methodObject);
196208

197209
/**
198-
* Target class where the method is defined. Used for static methods. For non
199-
* static methods the target object is used.
210+
* Target class where the method is defined. Used for static methods. For
211+
* non-static methods the target object is used.
200212
* @param targetClass method target class.
201213
*/
202214
MethodInvokingSpec targetClass(Class<?> targetClass);

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
package org.springframework.ai.model.function;
1818

1919
import java.util.function.BiFunction;
20+
import java.util.function.Consumer;
2021
import java.util.function.Function;
22+
import java.util.function.Supplier;
2123

2224
import com.fasterxml.jackson.annotation.JsonClassDescription;
25+
import kotlin.jvm.functions.Function0;
2326
import kotlin.jvm.functions.Function1;
2427
import kotlin.jvm.functions.Function2;
2528

@@ -30,6 +33,7 @@
3033
import org.springframework.context.annotation.Description;
3134
import org.springframework.context.support.GenericApplicationContext;
3235
import org.springframework.core.KotlinDetector;
36+
import org.springframework.core.ParameterizedTypeReference;
3337
import org.springframework.core.ResolvableType;
3438
import org.springframework.lang.NonNull;
3539
import org.springframework.lang.Nullable;
@@ -38,9 +42,9 @@
3842
/**
3943
* A Spring {@link ApplicationContextAware} implementation that provides a way to retrieve
4044
* a {@link Function} from the Spring context and wrap it into a {@link FunctionCallback}.
41-
*
45+
* <p>
4246
* The name of the function is determined by the bean name.
43-
*
47+
* <p>
4448
* The description of the function is determined by the following rules:
4549
* <ul>
4650
* <li>Provided as a default description</li>
@@ -69,24 +73,28 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
6973

7074
@SuppressWarnings({ "unchecked" })
7175
public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
72-
7376
ResolvableType functionType = TypeResolverHelper.resolveBeanType(this.applicationContext, beanName);
74-
ResolvableType functionInputType = TypeResolverHelper.getFunctionArgumentType(functionType, 0);
77+
ResolvableType functionInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(functionType))
78+
? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(functionType, 0);
79+
80+
String functionDescription = resolveFunctionDescription(beanName, defaultDescription,
81+
functionInputType.toClass());
82+
Object bean = this.applicationContext.getBean(beanName);
83+
84+
return buildFunctionCallback(beanName, functionType, functionInputType, functionDescription, bean);
85+
}
7586

76-
Class<?> functionInputClass = functionInputType.toClass();
87+
private String resolveFunctionDescription(String beanName, String defaultDescription, Class<?> functionInputClass) {
7788
String functionDescription = defaultDescription;
7889

7990
if (!StringUtils.hasText(functionDescription)) {
80-
// Look for a Description annotation on the bean
8191
Description descriptionAnnotation = this.applicationContext.findAnnotationOnBean(beanName,
8292
Description.class);
83-
8493
if (descriptionAnnotation != null) {
8594
functionDescription = descriptionAnnotation.value();
8695
}
8796

8897
if (!StringUtils.hasText(functionDescription)) {
89-
// Look for a JsonClassDescription annotation on the input class
9098
JsonClassDescription jsonClassDescriptionAnnotation = functionInputClass
9199
.getAnnotation(JsonClassDescription.class);
92100
if (jsonClassDescriptionAnnotation != null) {
@@ -95,51 +103,79 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
95103
}
96104

97105
if (!StringUtils.hasText(functionDescription)) {
98-
throw new IllegalStateException("Could not determine function description."
106+
throw new IllegalStateException("Could not determine function description. "
99107
+ "Please provide a description either as a default parameter, via @Description annotation on the bean "
100108
+ "or @JsonClassDescription annotation on the input class.");
101109
}
102110
}
103111

104-
Object bean = this.applicationContext.getBean(beanName);
112+
return functionDescription;
113+
}
114+
115+
private FunctionCallback buildFunctionCallback(String beanName, ResolvableType functionType,
116+
ResolvableType functionInputType, String functionDescription, Object bean) {
105117

106118
if (KotlinDetector.isKotlinPresent()) {
107119
if (KotlinDelegate.isKotlinFunction(functionType.toClass())) {
108120
return FunctionCallback.builder()
109121
.schemaType(this.schemaType)
110122
.description(functionDescription)
111123
.function(beanName, KotlinDelegate.wrapKotlinFunction(bean))
112-
.inputType(functionInputClass)
124+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
113125
.build();
114126
}
115-
else if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
127+
if (KotlinDelegate.isKotlinBiFunction(functionType.toClass())) {
116128
return FunctionCallback.builder()
117129
.description(functionDescription)
118130
.schemaType(this.schemaType)
119131
.function(beanName, KotlinDelegate.wrapKotlinBiFunction(bean))
120-
.inputType(functionInputClass)
132+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
133+
.build();
134+
}
135+
if (KotlinDelegate.isKotlinSupplier(functionType.toClass())) {
136+
return FunctionCallback.builder()
137+
.description(functionDescription)
138+
.schemaType(this.schemaType)
139+
.function(beanName, KotlinDelegate.wrapKotlinSupplier(bean))
140+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
121141
.build();
122142
}
123143
}
144+
124145
if (bean instanceof Function<?, ?> function) {
125146
return FunctionCallback.builder()
126147
.schemaType(this.schemaType)
127148
.description(functionDescription)
128149
.function(beanName, function)
129-
.inputType(functionInputClass)
150+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
130151
.build();
131152
}
132-
else if (bean instanceof BiFunction<?, ?, ?>) {
153+
if (bean instanceof BiFunction<?, ?, ?>) {
133154
return FunctionCallback.builder()
134155
.description(functionDescription)
135156
.schemaType(this.schemaType)
136157
.function(beanName, (BiFunction<?, ToolContext, ?>) bean)
137-
.inputType(functionInputClass)
158+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
159+
.build();
160+
}
161+
if (bean instanceof Supplier<?> supplier) {
162+
return FunctionCallback.builder()
163+
.description(functionDescription)
164+
.schemaType(this.schemaType)
165+
.function(beanName, supplier)
166+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
138167
.build();
139168
}
140-
else {
141-
throw new IllegalStateException();
169+
if (bean instanceof Consumer<?> consumer) {
170+
return FunctionCallback.builder()
171+
.description(functionDescription)
172+
.schemaType(this.schemaType)
173+
.function(beanName, consumer)
174+
.inputType(ParameterizedTypeReference.forType(functionInputType.getType()))
175+
.build();
142176
}
177+
178+
throw new IllegalStateException("Unsupported function type");
143179
}
144180

145181
public enum SchemaType {
@@ -148,7 +184,16 @@ public enum SchemaType {
148184

149185
}
150186

151-
private static class KotlinDelegate {
187+
private static final class KotlinDelegate {
188+
189+
public static boolean isKotlinSupplier(Class<?> clazz) {
190+
return Function0.class.isAssignableFrom(clazz);
191+
}
192+
193+
@SuppressWarnings("unchecked")
194+
public static Supplier<?> wrapKotlinSupplier(Object function) {
195+
return () -> ((Function0<Object>) function).invoke();
196+
}
152197

153198
public static boolean isKotlinFunction(Class<?> clazz) {
154199
return Function1.class.isAssignableFrom(clazz);

0 commit comments

Comments
 (0)