Skip to content

Commit c544d0c

Browse files
tzolovilayaperumalg
authored andcommitted
Add support for BiFunction class type resolution in TypeResolverHelper
- Adds special handling for BiFunction class types - Adds test cases to verify BiFunction class type resolution - Removes deprecated AbstractFunctionCallSupport class - Cleans up unused imports in FunctionCallback Resolves #1576
1 parent 3288c55 commit c544d0c

File tree

4 files changed

+59
-200
lines changed

4 files changed

+59
-200
lines changed

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java

Lines changed: 0 additions & 195 deletions
This file was deleted.

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
*/
1616
package org.springframework.ai.model.function;
1717

18-
import java.util.Map;
19-
2018
import org.springframework.ai.chat.model.ToolContext;
2119

2220
/**

spring-ai-core/src/main/java/org/springframework/ai/model/function/TypeResolverHelper.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,14 @@ public static Type getFunctionArgumentType(Type functionType, int argumentIndex)
130130

131131
// Resolves: https://github.com/spring-projects/spring-ai/issues/726
132132
if (!(functionType instanceof ParameterizedType)) {
133-
functionType = FunctionTypeUtils.discoverFunctionTypeFromClass(FunctionTypeUtils.getRawType(functionType));
133+
Class<?> functionalClass = FunctionTypeUtils.getRawType(functionType);
134+
// Resolves: https://github.com/spring-projects/spring-ai/issues/1576
135+
if (BiFunction.class.isAssignableFrom(functionalClass)) {
136+
functionType = TypeResolver.reify(BiFunction.class, (Class<BiFunction<?, ?, ?>>) functionalClass);
137+
}
138+
else {
139+
functionType = FunctionTypeUtils.discoverFunctionTypeFromClass(functionalClass);
140+
}
134141
}
135142

136143
var argumentType = functionType instanceof ParameterizedType

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class FunctionCallbackWithPlainFunctionBeanIT {
5959
.withUserConfiguration(Config.class);
6060

6161
@Test
62-
void functionCallTest2() {
62+
void functionCallWithDirectBiFunction() {
6363
contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName())
6464
.run(context -> {
6565

@@ -72,7 +72,7 @@ void functionCallTest2() {
7272
.toolContext(Map.of("sessionId", "123"))
7373
.call()
7474
.content();
75-
System.out.println(content);
75+
logger.info(content);
7676

7777
// Test weatherFunction
7878
UserMessage userMessage = new UserMessage(
@@ -91,6 +91,39 @@ void functionCallTest2() {
9191
});
9292
}
9393

94+
@Test
95+
void functionCallWithBiFunctionClass() {
96+
contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName())
97+
.run(context -> {
98+
99+
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
100+
101+
ChatClient chatClient = ChatClient.builder(chatModel).build();
102+
103+
String content = chatClient.prompt("What's the weather like in San Francisco, Tokyo, and Paris?")
104+
.functions("weatherFunctionWithClassBiFunction")
105+
.toolContext(Map.of("sessionId", "123"))
106+
.call()
107+
.content();
108+
logger.info(content);
109+
110+
// Test weatherFunction
111+
UserMessage userMessage = new UserMessage(
112+
"What's the weather like in San Francisco, Tokyo, and Paris? You can call the following functions 'weatherFunction'");
113+
114+
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
115+
OpenAiChatOptions.builder()
116+
.withFunction("weatherFunctionWithClassBiFunction")
117+
.withToolContext(Map.of("sessionId", "123"))
118+
.build()));
119+
120+
logger.info("Response: {}", response);
121+
122+
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
123+
124+
});
125+
}
126+
94127
@Test
95128
void functionCallTest() {
96129
contextRunner.withPropertyValues("spring.ai.openai.chat.options.model=" + ChatModel.GPT_4_O_MINI.getName())
@@ -196,6 +229,12 @@ void streamFunctionCallTest() {
196229
@Configuration
197230
static class Config {
198231

232+
@Bean
233+
@Description("Get the weather in location")
234+
public MyBiFunction weatherFunctionWithClassBiFunction() {
235+
return new MyBiFunction();
236+
}
237+
199238
@Bean
200239
@Description("Get the weather in location")
201240
public BiFunction<MockWeatherService.Request, ToolContext, MockWeatherService.Response> weatherFunctionWithContext() {
@@ -220,4 +259,14 @@ public Function<MockWeatherService.Request, MockWeatherService.Response> weather
220259

221260
}
222261

262+
public static class MyBiFunction
263+
implements BiFunction<MockWeatherService.Request, ToolContext, MockWeatherService.Response> {
264+
265+
@Override
266+
public MockWeatherService.Response apply(MockWeatherService.Request request, ToolContext context) {
267+
return new MockWeatherService().apply(request);
268+
}
269+
270+
}
271+
223272
}

0 commit comments

Comments
 (0)