Skip to content

Commit fb1746b

Browse files
Kehrlanntzolov
authored andcommitted
fix: handle root exceptions in McpToolCallback (#3553)
- Sync and Async McpToolCallback can now handle exceptions where cause == null instead of throwing a NPE. Signed-off-by: Daniel Garnier-Moiroux <git@garnier.wf>
1 parent e0d3703 commit fb1746b

File tree

4 files changed

+117
-29
lines changed

4 files changed

+117
-29
lines changed

mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
package org.springframework.ai.mcp;
1818

19-
import java.util.Map;
20-
2119
import io.modelcontextprotocol.client.McpAsyncClient;
20+
import io.modelcontextprotocol.spec.McpSchema;
2221
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
2322
import io.modelcontextprotocol.spec.McpSchema.Tool;
23+
import java.util.Map;
24+
import reactor.core.publisher.Mono;
2425

2526
import org.springframework.ai.chat.model.ToolContext;
2627
import org.springframework.ai.model.ModelOptionsUtils;
@@ -112,19 +113,16 @@ public String call(String functionInput) {
112113
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(functionInput);
113114
// Note that we use the original tool name here, not the adapted one from
114115
// getToolDefinition
115-
try {
116-
return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).map(response -> {
117-
if (response.isError() != null && response.isError()) {
118-
throw new ToolExecutionException(this.getToolDefinition(),
119-
new IllegalStateException("Error calling tool: " + response.content()));
120-
}
121-
return ModelOptionsUtils.toJsonString(response.content());
122-
}).block();
123-
}
124-
catch (Exception ex) {
125-
throw new ToolExecutionException(this.getToolDefinition(), ex.getCause());
126-
}
127-
116+
return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).onErrorMap(exception -> {
117+
// If the tool throws an error during execution
118+
throw new ToolExecutionException(this.getToolDefinition(), exception);
119+
}).map(response -> {
120+
if (response.isError() != null && response.isError()) {
121+
throw new ToolExecutionException(this.getToolDefinition(),
122+
new IllegalStateException("Error calling tool: " + response.content()));
123+
}
124+
return ModelOptionsUtils.toJsonString(response.content());
125+
}).block();
128126
}
129127

130128
@Override

mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616

1717
package org.springframework.ai.mcp;
1818

19-
import java.lang.reflect.InvocationTargetException;
20-
import java.util.Map;
21-
2219
import io.modelcontextprotocol.client.McpSyncClient;
2320
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
2421
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
2522
import io.modelcontextprotocol.spec.McpSchema.Tool;
23+
import java.util.Map;
2624
import org.slf4j.Logger;
2725
import org.slf4j.LoggerFactory;
2826

@@ -32,7 +30,6 @@
3230
import org.springframework.ai.tool.definition.DefaultToolDefinition;
3331
import org.springframework.ai.tool.definition.ToolDefinition;
3432
import org.springframework.ai.tool.execution.ToolExecutionException;
35-
import org.springframework.core.log.LogAccessor;
3633

3734
/**
3835
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
@@ -118,22 +115,24 @@ public ToolDefinition getToolDefinition() {
118115
@Override
119116
public String call(String functionInput) {
120117
Map<String, Object> arguments = ModelOptionsUtils.jsonToMap(functionInput);
121-
// Note that we use the original tool name here, not the adapted one from
122-
// getToolDefinition
118+
119+
CallToolResult response;
123120
try {
124-
CallToolResult response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments));
125-
if (response.isError() != null && response.isError()) {
126-
logger.error("Error calling tool: {}", response.content());
127-
throw new ToolExecutionException(this.getToolDefinition(),
128-
new IllegalStateException("Error calling tool: " + response.content()));
129-
}
130-
return ModelOptionsUtils.toJsonString(response.content());
121+
// Note that we use the original tool name here, not the adapted one from
122+
// getToolDefinition
123+
response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments));
131124
}
132125
catch (Exception ex) {
133126
logger.error("Exception while tool calling: ", ex);
134-
throw new ToolExecutionException(this.getToolDefinition(), ex.getCause());
127+
throw new ToolExecutionException(this.getToolDefinition(), ex);
135128
}
136129

130+
if (response.isError() != null && response.isError()) {
131+
logger.error("Error calling tool: {}", response.content());
132+
throw new ToolExecutionException(this.getToolDefinition(),
133+
new IllegalStateException("Error calling tool: " + response.content()));
134+
}
135+
return ModelOptionsUtils.toJsonString(response.content());
137136
}
138137

139138
@Override
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package org.springframework.ai.mcp;
2+
3+
import io.modelcontextprotocol.client.McpAsyncClient;
4+
import io.modelcontextprotocol.spec.McpSchema;
5+
import org.junit.jupiter.api.Test;
6+
import org.junit.jupiter.api.extension.ExtendWith;
7+
import org.mockito.Mock;
8+
import org.mockito.junit.jupiter.MockitoExtension;
9+
import reactor.core.publisher.Mono;
10+
11+
import org.springframework.ai.tool.execution.ToolExecutionException;
12+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
13+
import static org.mockito.ArgumentMatchers.any;
14+
import static org.mockito.Mockito.when;
15+
16+
@ExtendWith(MockitoExtension.class)
17+
class AsyncMcpToolCallbackTest {
18+
19+
@Mock
20+
private McpAsyncClient mcpClient;
21+
22+
@Mock
23+
private McpSchema.Tool tool;
24+
25+
@Test
26+
void callShouldThrowOnError() {
27+
when(this.tool.name()).thenReturn("testTool");
28+
var clientInfo = new McpSchema.Implementation("testClient", "1.0.0");
29+
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
30+
var callToolResult = McpSchema.CallToolResult.builder().addTextContent("Some error data").isError(true).build();
31+
when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(Mono.just(callToolResult));
32+
33+
var callback = new AsyncMcpToolCallback(this.mcpClient, this.tool);
34+
assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
35+
.cause()
36+
.isInstanceOf(IllegalStateException.class)
37+
.hasMessage("Error calling tool: [TextContent[audience=null, priority=null, text=Some error data]]");
38+
}
39+
40+
@Test
41+
void callShouldWrapReactiveErrors() {
42+
when(this.tool.name()).thenReturn("testTool");
43+
var clientInfo = new McpSchema.Implementation("testClient", "1.0.0");
44+
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
45+
when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class)))
46+
.thenReturn(Mono.error(new Exception("Testing tool error")));
47+
48+
var callback = new AsyncMcpToolCallback(this.mcpClient, this.tool);
49+
assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
50+
.rootCause()
51+
.hasMessage("Testing tool error");
52+
}
53+
54+
}

mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.mcp;
1818

19+
import io.modelcontextprotocol.spec.McpSchema;
20+
import java.util.List;
1921
import java.util.Map;
2022

2123
import io.modelcontextprotocol.client.McpSyncClient;
@@ -29,8 +31,11 @@
2931
import org.mockito.junit.jupiter.MockitoExtension;
3032

3133
import org.springframework.ai.chat.model.ToolContext;
34+
import org.springframework.ai.content.Content;
35+
import org.springframework.ai.tool.execution.ToolExecutionException;
3236

3337
import static org.assertj.core.api.Assertions.assertThat;
38+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3439
import static org.mockito.ArgumentMatchers.any;
3540
import static org.mockito.Mockito.mock;
3641
import static org.mockito.Mockito.when;
@@ -94,4 +99,36 @@ void callShouldIgnoreToolContext() {
9499
assertThat(response).isNotNull();
95100
}
96101

102+
@Test
103+
void callShouldThrowOnError() {
104+
when(this.tool.name()).thenReturn("testTool");
105+
var clientInfo = new Implementation("testClient", "1.0.0");
106+
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
107+
CallToolResult callResult = mock(CallToolResult.class);
108+
when(callResult.isError()).thenReturn(true);
109+
when(callResult.content()).thenReturn(List.of(new McpSchema.TextContent("Some error data")));
110+
when(this.mcpClient.callTool(any(CallToolRequest.class))).thenReturn(callResult);
111+
112+
SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool);
113+
114+
assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
115+
.cause()
116+
.isInstanceOf(IllegalStateException.class)
117+
.hasMessage("Error calling tool: [TextContent[audience=null, priority=null, text=Some error data]]");
118+
}
119+
120+
@Test
121+
void callShouldWrapExceptions() {
122+
when(this.tool.name()).thenReturn("testTool");
123+
var clientInfo = new Implementation("testClient", "1.0.0");
124+
when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
125+
when(this.mcpClient.callTool(any(CallToolRequest.class))).thenThrow(new RuntimeException("Testing tool error"));
126+
127+
SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool);
128+
129+
assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class)
130+
.rootCause()
131+
.hasMessage("Testing tool error");
132+
}
133+
97134
}

0 commit comments

Comments
 (0)