Skip to content

Commit b9a6834

Browse files
sunyuhan1998ilayaperumalg
authored andcommitted
test: Add unit test to verify multiple method toolcallbacks with toolcontext
Auto-cherry-pick to 1.0.x Signed-off-by: Sun Yuhan <sunyuhan1998@users.noreply.github.com>
1 parent 16a7084 commit b9a6834

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.model.tool;
1818

19+
import java.lang.reflect.Method;
1920
import java.util.List;
2021
import java.util.Map;
2122

@@ -27,13 +28,15 @@
2728
import org.springframework.ai.chat.messages.UserMessage;
2829
import org.springframework.ai.chat.model.ChatResponse;
2930
import org.springframework.ai.chat.model.Generation;
31+
import org.springframework.ai.chat.model.ToolContext;
3032
import org.springframework.ai.chat.prompt.Prompt;
3133
import org.springframework.ai.tool.ToolCallback;
3234
import org.springframework.ai.tool.definition.DefaultToolDefinition;
3335
import org.springframework.ai.tool.definition.ToolDefinition;
3436
import org.springframework.ai.tool.execution.ToolExecutionException;
3537
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
3638
import org.springframework.ai.tool.metadata.ToolMetadata;
39+
import org.springframework.ai.tool.method.MethodToolCallback;
3740
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
3841
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
3942

@@ -45,6 +48,7 @@
4548
* Unit tests for {@link DefaultToolCallingManager}.
4649
*
4750
* @author Thomas Vitale
51+
* @author Sun Yuhan
4852
*/
4953
class DefaultToolCallingManagerTests {
5054

@@ -317,6 +321,49 @@ void whenToolCallWithExceptionThenReturnError() {
317321
assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse);
318322
}
319323

324+
@Test
325+
void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodException {
326+
ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build();
327+
328+
ToolDefinition toolDefinitionA = ToolDefinition.builder().name("toolA").inputSchema("{}").build();
329+
Method methodA = TestGenericClass.class.getMethod("call", String.class);
330+
MethodToolCallback methodToolCallback = MethodToolCallback.builder()
331+
.toolDefinition(toolDefinitionA)
332+
.toolMethod(methodA)
333+
.toolObject(new TestGenericClass())
334+
.build();
335+
336+
ToolDefinition toolDefinitionB = ToolDefinition.builder().name("toolB").inputSchema("{}").build();
337+
Method methodB = TestGenericClass.class.getMethod("callWithToolContext", ToolContext.class);
338+
MethodToolCallback methodToolCallbackNeedToolContext = MethodToolCallback.builder()
339+
.toolDefinition(toolDefinitionB)
340+
.toolMethod(methodB)
341+
.toolObject(new TestGenericClass())
342+
.build();
343+
344+
Prompt prompt = new Prompt(new UserMessage("Hello"),
345+
ToolCallingChatOptions.builder()
346+
.toolCallbacks(methodToolCallback, methodToolCallbackNeedToolContext)
347+
.toolNames("toolA", "toolB")
348+
.toolContext("key", "value")
349+
.build());
350+
351+
ChatResponse chatResponse = ChatResponse.builder()
352+
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
353+
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
354+
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))))))
355+
.build();
356+
357+
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
358+
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON),
359+
new ToolResponseMessage.ToolResponse("toolB", "toolB",
360+
TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON)));
361+
362+
ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);
363+
364+
assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse);
365+
}
366+
320367
static class TestToolCallback implements ToolCallback {
321368

322369
private final ToolDefinition toolDefinition;
@@ -370,4 +417,31 @@ public String call(String toolInput) {
370417

371418
}
372419

420+
/**
421+
* Test class with methods that use generic types.
422+
*/
423+
static class TestGenericClass {
424+
425+
public final static String CALL_RESULT_JSON = """
426+
{
427+
"result": "Mission accomplished!"
428+
}
429+
""";
430+
431+
public final static String CALL_WITH_TOOL_CONTEXT_RESULT_JSON = """
432+
{
433+
"result": "ToolContext mission accomplished!"
434+
}
435+
""";
436+
437+
public String call(String toolInput) {
438+
return CALL_RESULT_JSON;
439+
}
440+
441+
public String callWithToolContext(ToolContext toolContext) {
442+
return CALL_WITH_TOOL_CONTEXT_RESULT_JSON;
443+
}
444+
445+
}
446+
373447
}

0 commit comments

Comments
 (0)