Skip to content

Commit 2e6d6ea

Browse files
alexandreromantzolov
authored andcommitted
feat(multimudality): Add support for base64-encoded images in tool call results (#2368)
- Enhance McpToolUtils to handle base64-encoded images in JSON responses - Add Base64Wrapper record to parse JSON structures containing base64 image data - Implement image conversion in DefaultToolCallResultConverter to encode RenderedImage as base64 PNG - Add tests for DefaultToolCallResultConverter including image conversion - Gracefully handle unsupported JSON structure for base64 wrappers Signed-off-by: Alexandre Roman <alexandre.roman@broadcom.com>
1 parent 14e7033 commit 2e6d6ea

File tree

3 files changed

+177
-7
lines changed

3 files changed

+177
-7
lines changed

mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
import java.util.List;
1919
import java.util.Map;
2020

21+
import com.fasterxml.jackson.annotation.JsonAlias;
22+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
2123
import io.micrometer.common.util.StringUtils;
2224
import io.modelcontextprotocol.client.McpAsyncClient;
2325
import io.modelcontextprotocol.client.McpSyncClient;
2426
import io.modelcontextprotocol.server.McpServerFeatures;
25-
import io.modelcontextprotocol.server.McpSyncServerExchange;
2627
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolRegistration;
2728
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
29+
import io.modelcontextprotocol.server.McpSyncServerExchange;
2830
import io.modelcontextprotocol.spec.McpSchema;
2931
import io.modelcontextprotocol.spec.McpSchema.Role;
3032
import reactor.core.publisher.Mono;
@@ -33,6 +35,8 @@
3335
import org.springframework.ai.chat.model.ToolContext;
3436
import org.springframework.ai.model.ModelOptionsUtils;
3537
import org.springframework.ai.tool.ToolCallback;
38+
import org.springframework.ai.util.json.JsonParser;
39+
import org.springframework.lang.Nullable;
3640
import org.springframework.util.CollectionUtils;
3741
import org.springframework.util.MimeType;
3842

@@ -234,9 +238,22 @@ public static McpServerFeatures.SyncToolRegistration toSyncToolRegistration(Tool
234238
return new McpServerFeatures.SyncToolRegistration(tool, request -> {
235239
try {
236240
String callResult = toolCallback.call(ModelOptionsUtils.toJsonString(request));
237-
if (mimeType != null && mimeType.toString().startsWith("image")) {
238-
return new McpSchema.CallToolResult(List
239-
.of(new McpSchema.ImageContent(List.of(Role.ASSISTANT), null, callResult, mimeType.toString())),
241+
String imgData = callResult;
242+
if (mimeType != null && "image".equals(mimeType.getType())) {
243+
String imgType = mimeType.toString();
244+
if (callResult.startsWith("{") && callResult.endsWith("}")) {
245+
// This is most likely a JSON structure:
246+
// let's try to parse it as a base64 wrapper.
247+
var b64Struct = JsonParser.fromJson(callResult, Base64Wrapper.class);
248+
if (b64Struct.mimeType() != null && b64Struct.data() != null
249+
&& b64Struct.mimeType.getType().equals("image")) {
250+
// Get the base64 encoded image as is.
251+
imgType = b64Struct.mimeType().toString();
252+
imgData = b64Struct.data();
253+
}
254+
}
255+
return new McpSchema.CallToolResult(
256+
List.of(new McpSchema.ImageContent(List.of(Role.ASSISTANT), null, imgData, imgType)),
240257
false);
241258
}
242259
return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(callResult)), false);
@@ -547,4 +564,9 @@ public static List<ToolCallback> getToolCallbacksFromAsyncClients(List<McpAsyncC
547564
return List.of((new AsyncMcpToolCallbackProvider(asyncMcpClients).getToolCallbacks()));
548565
}
549566

567+
@JsonIgnoreProperties(ignoreUnknown = true)
568+
private record Base64Wrapper(@JsonAlias("mimetype") @Nullable MimeType mimeType, @JsonAlias( {
569+
"base64", "b64", "imageData" }) @Nullable String data){
570+
}
571+
550572
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package org.springframework.ai.tool.execution;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.springframework.ai.util.json.JsonParser;
5+
import org.springframework.util.MimeType;
6+
import org.springframework.util.MimeTypeUtils;
7+
8+
import javax.imageio.ImageIO;
9+
import java.awt.Color;
10+
import java.awt.image.BufferedImage;
11+
import java.io.ByteArrayInputStream;
12+
import java.io.IOException;
13+
import java.util.Base64;
14+
import java.util.List;
15+
import java.util.Map;
16+
17+
import static org.assertj.core.api.Assertions.assertThat;
18+
19+
/**
20+
* Unit tests for {@link DefaultToolCallResultConverter}.
21+
*
22+
* @author Thomas Vitale
23+
*/
24+
class DefaultToolCallResultConverterTests {
25+
26+
private final DefaultToolCallResultConverter converter = new DefaultToolCallResultConverter();
27+
28+
@Test
29+
void convertWithNullReturnTypeShouldReturn() {
30+
String result = converter.convert(null, null);
31+
assertThat(result).isEqualTo("null");
32+
}
33+
34+
@Test
35+
void convertVoidReturnTypeShouldReturnDone() {
36+
String result = converter.convert(null, void.class);
37+
assertThat(result).isEqualTo("Done");
38+
}
39+
40+
@Test
41+
void convertStringReturnTypeShouldReturnJson() {
42+
String result = converter.convert("test", String.class);
43+
assertThat(result).isEqualTo("\"test\"");
44+
}
45+
46+
@Test
47+
void convertNullReturnValueShouldReturnNullJson() {
48+
String result = converter.convert(null, String.class);
49+
assertThat(result).isEqualTo("null");
50+
}
51+
52+
@Test
53+
void convertObjectReturnTypeShouldReturnJson() {
54+
TestObject testObject = new TestObject("test", 42);
55+
String result = converter.convert(testObject, TestObject.class);
56+
assertThat(result).containsIgnoringWhitespaces("""
57+
"name": "test"
58+
""").containsIgnoringWhitespaces("""
59+
"value": 42
60+
""");
61+
}
62+
63+
@Test
64+
void convertCollectionReturnTypeShouldReturnJson() {
65+
List<String> testList = List.of("one", "two", "three");
66+
String result = converter.convert(testList, List.class);
67+
assertThat(result).isEqualTo("""
68+
["one","two","three"]
69+
""".trim());
70+
}
71+
72+
@Test
73+
void convertMapReturnTypeShouldReturnJson() {
74+
Map<String, Integer> testMap = Map.of("one", 1, "two", 2);
75+
String result = converter.convert(testMap, Map.class);
76+
assertThat(result).containsIgnoringWhitespaces("""
77+
"one": 1
78+
""").containsIgnoringWhitespaces("""
79+
"two": 2
80+
""");
81+
}
82+
83+
@Test
84+
void convertImageShouldReturnBase64Image() throws IOException {
85+
// We don't want any AWT windows.
86+
System.setProperty("java.awt.headless", "true");
87+
88+
var img = new BufferedImage(64, 64, BufferedImage.TYPE_4BYTE_ABGR);
89+
var g = img.createGraphics();
90+
g.setColor(Color.WHITE);
91+
g.fillRect(0, 0, 64, 64);
92+
g.dispose();
93+
String result = converter.convert(img, BufferedImage.class);
94+
95+
var b64Struct = JsonParser.fromJson(result, Base64Wrapper.class);
96+
assertThat(b64Struct.mimeType).isEqualTo(MimeTypeUtils.IMAGE_PNG);
97+
assertThat(b64Struct.data).isNotNull();
98+
99+
var imgData = Base64.getDecoder().decode(b64Struct.data);
100+
assertThat(imgData.length).isNotZero();
101+
102+
var imgRes = ImageIO.read(new ByteArrayInputStream(imgData));
103+
assertThat(imgRes.getWidth()).isEqualTo(64);
104+
assertThat(imgRes.getHeight()).isEqualTo(64);
105+
assertThat(imgRes.getRGB(0, 0)).isEqualTo(img.getRGB(0, 0));
106+
}
107+
108+
record Base64Wrapper(MimeType mimeType, String data) {
109+
}
110+
111+
static class TestObject {
112+
113+
private final String name;
114+
115+
private final int value;
116+
117+
TestObject(String name, int value) {
118+
this.name = name;
119+
this.value = value;
120+
}
121+
122+
public String getName() {
123+
return name;
124+
}
125+
126+
public int getValue() {
127+
return value;
128+
}
129+
130+
}
131+
132+
}

spring-ai-model/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616

1717
package org.springframework.ai.tool.execution;
1818

19-
import java.lang.reflect.Type;
20-
2119
import org.slf4j.Logger;
2220
import org.slf4j.LoggerFactory;
23-
2421
import org.springframework.ai.util.json.JsonParser;
2522
import org.springframework.lang.Nullable;
2623

24+
import javax.imageio.ImageIO;
25+
import java.awt.image.RenderedImage;
26+
import java.io.ByteArrayOutputStream;
27+
import java.io.IOException;
28+
import java.lang.reflect.Type;
29+
import java.util.Base64;
30+
import java.util.Map;
31+
2732
/**
2833
* A default implementation of {@link ToolCallResultConverter}.
2934
*
@@ -40,6 +45,17 @@ public String convert(@Nullable Object result, @Nullable Type returnType) {
4045
logger.debug("The tool has no return type. Converting to conventional response.");
4146
return "Done";
4247
}
48+
if (result instanceof RenderedImage) {
49+
final var buf = new ByteArrayOutputStream(1024 * 4);
50+
try {
51+
ImageIO.write((RenderedImage) result, "PNG", buf);
52+
}
53+
catch (IOException e) {
54+
return "Failed to convert tool result to a base64 image: " + e.getMessage();
55+
}
56+
final var imgB64 = Base64.getEncoder().encodeToString(buf.toByteArray());
57+
return JsonParser.toJson(Map.of("mimeType", "image/png", "data", imgB64));
58+
}
4359
else {
4460
logger.debug("Converting tool result to JSON.");
4561
return JsonParser.toJson(result);

0 commit comments

Comments
 (0)