|
16 | 16 |
|
17 | 17 | package org.springframework.ai.model.tool;
|
18 | 18 |
|
| 19 | +import java.lang.reflect.Method; |
19 | 20 | import java.util.List;
|
20 | 21 | import java.util.Map;
|
21 | 22 |
|
|
27 | 28 | import org.springframework.ai.chat.messages.UserMessage;
|
28 | 29 | import org.springframework.ai.chat.model.ChatResponse;
|
29 | 30 | import org.springframework.ai.chat.model.Generation;
|
| 31 | +import org.springframework.ai.chat.model.ToolContext; |
30 | 32 | import org.springframework.ai.chat.prompt.Prompt;
|
31 | 33 | import org.springframework.ai.tool.ToolCallback;
|
32 | 34 | import org.springframework.ai.tool.definition.DefaultToolDefinition;
|
33 | 35 | import org.springframework.ai.tool.definition.ToolDefinition;
|
34 | 36 | import org.springframework.ai.tool.execution.ToolExecutionException;
|
35 | 37 | import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
|
36 | 38 | import org.springframework.ai.tool.metadata.ToolMetadata;
|
| 39 | +import org.springframework.ai.tool.method.MethodToolCallback; |
37 | 40 | import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
|
38 | 41 | import org.springframework.ai.tool.resolution.ToolCallbackResolver;
|
39 | 42 |
|
|
45 | 48 | * Unit tests for {@link DefaultToolCallingManager}.
|
46 | 49 | *
|
47 | 50 | * @author Thomas Vitale
|
| 51 | + * @author Sun Yuhan |
48 | 52 | */
|
49 | 53 | class DefaultToolCallingManagerTests {
|
50 | 54 |
|
@@ -317,6 +321,49 @@ void whenToolCallWithExceptionThenReturnError() {
|
317 | 321 | assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse);
|
318 | 322 | }
|
319 | 323 |
|
| 324 | + @Test |
| 325 | + void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodException { |
| 326 | + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); |
| 327 | + |
| 328 | + ToolDefinition toolDefinitionA = ToolDefinition.builder().name("toolA").inputSchema("{}").build(); |
| 329 | + Method methodA = TestGenericClass.class.getMethod("call", String.class); |
| 330 | + MethodToolCallback methodToolCallback = MethodToolCallback.builder() |
| 331 | + .toolDefinition(toolDefinitionA) |
| 332 | + .toolMethod(methodA) |
| 333 | + .toolObject(new TestGenericClass()) |
| 334 | + .build(); |
| 335 | + |
| 336 | + ToolDefinition toolDefinitionB = ToolDefinition.builder().name("toolB").inputSchema("{}").build(); |
| 337 | + Method methodB = TestGenericClass.class.getMethod("callWithToolContext", ToolContext.class); |
| 338 | + MethodToolCallback methodToolCallbackNeedToolContext = MethodToolCallback.builder() |
| 339 | + .toolDefinition(toolDefinitionB) |
| 340 | + .toolMethod(methodB) |
| 341 | + .toolObject(new TestGenericClass()) |
| 342 | + .build(); |
| 343 | + |
| 344 | + Prompt prompt = new Prompt(new UserMessage("Hello"), |
| 345 | + ToolCallingChatOptions.builder() |
| 346 | + .toolCallbacks(methodToolCallback, methodToolCallbackNeedToolContext) |
| 347 | + .toolNames("toolA", "toolB") |
| 348 | + .toolContext("key", "value") |
| 349 | + .build()); |
| 350 | + |
| 351 | + ChatResponse chatResponse = ChatResponse.builder() |
| 352 | + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), |
| 353 | + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), |
| 354 | + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) |
| 355 | + .build(); |
| 356 | + |
| 357 | + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( |
| 358 | + List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON), |
| 359 | + new ToolResponseMessage.ToolResponse("toolB", "toolB", |
| 360 | + TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON))); |
| 361 | + |
| 362 | + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); |
| 363 | + |
| 364 | + assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); |
| 365 | + } |
| 366 | + |
320 | 367 | static class TestToolCallback implements ToolCallback {
|
321 | 368 |
|
322 | 369 | private final ToolDefinition toolDefinition;
|
@@ -370,4 +417,31 @@ public String call(String toolInput) {
|
370 | 417 |
|
371 | 418 | }
|
372 | 419 |
|
| 420 | + /** |
| 421 | + * Test class with methods that use generic types. |
| 422 | + */ |
| 423 | + static class TestGenericClass { |
| 424 | + |
| 425 | + public final static String CALL_RESULT_JSON = """ |
| 426 | + { |
| 427 | + "result": "Mission accomplished!" |
| 428 | + } |
| 429 | + """; |
| 430 | + |
| 431 | + public final static String CALL_WITH_TOOL_CONTEXT_RESULT_JSON = """ |
| 432 | + { |
| 433 | + "result": "ToolContext mission accomplished!" |
| 434 | + } |
| 435 | + """; |
| 436 | + |
| 437 | + public String call(String toolInput) { |
| 438 | + return CALL_RESULT_JSON; |
| 439 | + } |
| 440 | + |
| 441 | + public String callWithToolContext(ToolContext toolContext) { |
| 442 | + return CALL_WITH_TOOL_CONTEXT_RESULT_JSON; |
| 443 | + } |
| 444 | + |
| 445 | + } |
| 446 | + |
373 | 447 | }
|
0 commit comments