Skip to content

Commit 2ca1be2

Browse files
sunyuhan1998ilayaperumalg
authored andcommitted
fix: corrected a logic error in the validateToolContextSupport method caused by incorrect parameter order.
Fixes #GH-3466 Signed-off-by: Sun Yuhan <sunyuhan1998@users.noreply.github.com>
1 parent fcbdac9 commit 2ca1be2

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {
118118
private void validateToolContextSupport(@Nullable ToolContext toolContext) {
119119
var isNonEmptyToolContextProvided = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext());
120120
var isToolContextAcceptedByMethod = Stream.of(this.toolMethod.getParameterTypes())
121-
.anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class));
122-
if (isToolContextAcceptedByMethod && !isNonEmptyToolContextProvided) {
121+
.anyMatch(type -> ClassUtils.isAssignable(ToolContext.class, type));
122+
if (isNonEmptyToolContextProvided && !isToolContextAcceptedByMethod) {
123123
throw new IllegalArgumentException("ToolContext is required by the method as an argument");
124124
}
125125
}

spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222

2323
import org.junit.jupiter.api.Test;
2424

25+
import org.springframework.ai.chat.model.ToolContext;
2526
import org.springframework.ai.tool.definition.DefaultToolDefinition;
2627
import org.springframework.ai.tool.definition.ToolDefinition;
2728

2829
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2931

3032
/**
3133
* Tests for {@link MethodToolCallback} with generic types.
@@ -137,6 +139,76 @@ void testNestedGenericType() throws Exception {
137139
assertThat(result).isEqualTo("2 maps processed: [{a=1, b=2}, {c=3, d=4}]");
138140
}
139141

142+
@Test
143+
void testToolContextType() throws Exception {
144+
// Create a test object with a method that takes a List<Map<String, Integer>>
145+
TestGenericClass testObject = new TestGenericClass();
146+
Method method = TestGenericClass.class.getMethod("processStringListInToolContext", ToolContext.class);
147+
148+
// Create a tool definition
149+
ToolDefinition toolDefinition = DefaultToolDefinition.builder()
150+
.name("processToolContext")
151+
.description("Process tool context")
152+
.inputSchema("{}")
153+
.build();
154+
155+
// Create a MethodToolCallback
156+
MethodToolCallback callback = MethodToolCallback.builder()
157+
.toolDefinition(toolDefinition)
158+
.toolMethod(method)
159+
.toolObject(testObject)
160+
.build();
161+
162+
// Create an empty JSON input
163+
String toolInput = """
164+
{}
165+
""";
166+
167+
// Create a toolContext
168+
ToolContext toolContext = new ToolContext(Map.of("foo", "bar"));
169+
170+
// Call the tool
171+
String result = callback.call(toolInput, toolContext);
172+
173+
// Verify the result
174+
assertThat(result).isEqualTo("1 entries processed {foo=bar}");
175+
}
176+
177+
@Test
178+
void testToolContextTypeWithNonToolContextArgs() throws Exception {
179+
// Create a test object with a method that takes a List<String>
180+
TestGenericClass testObject = new TestGenericClass();
181+
Method method = TestGenericClass.class.getMethod("processStringList", List.class);
182+
183+
// Create a tool definition
184+
ToolDefinition toolDefinition = DefaultToolDefinition.builder()
185+
.name("processStringList")
186+
.description("Process a list of strings")
187+
.inputSchema("{}")
188+
.build();
189+
190+
// Create a MethodToolCallback
191+
MethodToolCallback callback = MethodToolCallback.builder()
192+
.toolDefinition(toolDefinition)
193+
.toolMethod(method)
194+
.toolObject(testObject)
195+
.build();
196+
197+
// Create a JSON input with a list of strings
198+
String toolInput = """
199+
{
200+
"strings": ["one", "two", "three"]
201+
}
202+
""";
203+
204+
// Create a toolContext
205+
ToolContext toolContext = new ToolContext(Map.of("foo", "bar"));
206+
207+
// Call the tool and verify
208+
assertThatThrownBy(() -> callback.call(toolInput, toolContext)).isInstanceOf(IllegalArgumentException.class)
209+
.hasMessageContaining("ToolContext is required by the method as an argument");
210+
}
211+
140212
/**
141213
* Test class with methods that use generic types.
142214
*/
@@ -154,6 +226,11 @@ public String processListOfMaps(List<Map<String, Integer>> listOfMaps) {
154226
return listOfMaps.size() + " maps processed: " + listOfMaps;
155227
}
156228

229+
public String processStringListInToolContext(ToolContext toolContext) {
230+
Map<String, Object> context = toolContext.getContext();
231+
return context.size() + " entries processed " + context;
232+
}
233+
157234
}
158235

159236
}

0 commit comments

Comments
 (0)