Skip to content

Commit 6c52c99

Browse files
tzolovmarkpollack
authored andcommitted
feat: Improve validation in MethodToolCallbackProvider
This ensures that validation errors are caught early during object construction rather than later when methods are called, providing better error feedback. - Add validation for tool-annotated methods during construction - Validate duplicate tool names in constructor instead of only at getToolCallbacks() time - Add comprehensive test suite for MethodToolCallbackProvider Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 90cab21 commit 6c52c99

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.lang.reflect.Method;
2020
import java.util.Arrays;
2121
import java.util.List;
22+
import java.util.Optional;
2223
import java.util.function.Consumer;
2324
import java.util.function.Function;
2425
import java.util.function.Supplier;
@@ -44,6 +45,7 @@
4445
* {@link Tool}-annotated methods.
4546
*
4647
* @author Thomas Vitale
48+
* @author Christian Tzolov
4749
* @since 1.0.0
4850
*/
4951
public final class MethodToolCallbackProvider implements ToolCallbackProvider {
@@ -55,7 +57,26 @@ public final class MethodToolCallbackProvider implements ToolCallbackProvider {
5557
private MethodToolCallbackProvider(List<Object> toolObjects) {
5658
Assert.notNull(toolObjects, "toolObjects cannot be null");
5759
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
60+
assertToolAnnotatedMethodsPresent(toolObjects);
5861
this.toolObjects = toolObjects;
62+
validateToolCallbacks(getToolCallbacks());
63+
}
64+
65+
private void assertToolAnnotatedMethodsPresent(List<Object> toolObjects) {
66+
67+
for (Object toolObject : toolObjects) {
68+
List<Method> toolMethods = Stream
69+
.of(ReflectionUtils.getDeclaredMethods(
70+
AopUtils.isAopProxy(toolObject) ? AopUtils.getTargetClass(toolObject) : toolObject.getClass()))
71+
.filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class))
72+
.filter(toolMethod -> !isFunctionalType(toolMethod))
73+
.toList();
74+
75+
if (toolMethods.isEmpty()) {
76+
throw new IllegalStateException("No @Tool annotated methods found in " + toolObject + "."
77+
+ "Did you mean to pass a ToolCallback or ToolCallbackProvider? If so, you have to use .toolCallbacks() instead of .tool()");
78+
}
79+
}
5980
}
6081

6182
@Override
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.tool.method;
18+
19+
import java.util.function.Consumer;
20+
import java.util.function.Function;
21+
import java.util.function.Supplier;
22+
23+
import org.junit.jupiter.api.Test;
24+
25+
import org.springframework.ai.tool.annotation.Tool;
26+
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
29+
30+
/**
31+
* Unit tests for {@link MethodToolCallbackProvider}.
32+
*
33+
* @author Christian Tzolov
34+
*/
35+
class MethodToolCallbackProviderTests {
36+
37+
@Test
38+
void whenToolObjectHasToolAnnotatedMethodThenSucceed() {
39+
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
40+
.toolObjects(new ValidToolObject())
41+
.build();
42+
43+
assertThat(provider.getToolCallbacks()).hasSize(1);
44+
assertThat(provider.getToolCallbacks()[0].getToolDefinition().name()).isEqualTo("validTool");
45+
}
46+
47+
@Test
48+
void whenToolObjectHasNoToolAnnotatedMethodThenThrow() {
49+
assertThatThrownBy(
50+
() -> MethodToolCallbackProvider.builder().toolObjects(new NoToolAnnotatedMethodObject()).build())
51+
.isInstanceOf(IllegalStateException.class)
52+
.hasMessageContaining("No @Tool annotated methods found in");
53+
}
54+
55+
@Test
56+
void whenToolObjectHasOnlyFunctionalTypeToolMethodsThenThrow() {
57+
assertThatThrownBy(() -> MethodToolCallbackProvider.builder()
58+
.toolObjects(new OnlyFunctionalTypeToolMethodsObject())
59+
.build()).isInstanceOf(IllegalStateException.class)
60+
.hasMessageContaining("No @Tool annotated methods found in");
61+
}
62+
63+
@Test
64+
void whenToolObjectHasMixOfValidAndFunctionalTypeToolMethodsThenSucceed() {
65+
MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder()
66+
.toolObjects(new MixedToolMethodsObject())
67+
.build();
68+
69+
assertThat(provider.getToolCallbacks()).hasSize(1);
70+
assertThat(provider.getToolCallbacks()[0].getToolDefinition().name()).isEqualTo("validTool");
71+
}
72+
73+
@Test
74+
void whenMultipleToolObjectsWithSameToolNameThenThrow() {
75+
assertThatThrownBy(() -> MethodToolCallbackProvider.builder()
76+
.toolObjects(new ValidToolObject(), new DuplicateToolNameObject())
77+
.build()).isInstanceOf(IllegalStateException.class)
78+
.hasMessageContaining("Multiple tools with the same name (validTool) found in sources");
79+
}
80+
81+
static class ValidToolObject {
82+
83+
@Tool
84+
public String validTool() {
85+
return "Valid tool result";
86+
}
87+
88+
}
89+
90+
static class NoToolAnnotatedMethodObject {
91+
92+
public String notATool() {
93+
return "Not a tool";
94+
}
95+
96+
}
97+
98+
static class OnlyFunctionalTypeToolMethodsObject {
99+
100+
@Tool
101+
public Function<String, String> functionTool() {
102+
return input -> "Function result: " + input;
103+
}
104+
105+
@Tool
106+
public Supplier<String> supplierTool() {
107+
return () -> "Supplier result";
108+
}
109+
110+
@Tool
111+
public Consumer<String> consumerTool() {
112+
return input -> System.out.println("Consumer received: " + input);
113+
}
114+
115+
}
116+
117+
static class MixedToolMethodsObject {
118+
119+
@Tool
120+
public String validTool() {
121+
return "Valid tool result";
122+
}
123+
124+
@Tool
125+
public Function<String, String> functionTool() {
126+
return input -> "Function result: " + input;
127+
}
128+
129+
}
130+
131+
static class DuplicateToolNameObject {
132+
133+
@Tool
134+
public String validTool() {
135+
return "Duplicate tool result";
136+
}
137+
138+
}
139+
140+
}

0 commit comments

Comments
 (0)