Skip to content

Commit 2c3905a

Browse files
committed
Merge branch 'refs/heads/main' into feat/header
# Conflicts: # mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java
2 parents ed4f8e7 + f3b0774 commit 2c3905a

File tree

10 files changed

+280
-22
lines changed

10 files changed

+280
-22
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@ This SDK enables Java applications to interact with AI models and tools through
77
## 📚 Reference Documentation
88

99
#### MCP Java SDK documentation
10-
For comprehensive guides and SDK API documentation, visit the [MCP Java SDK Reference Documentation](https://modelcontextprotocol.io/sdk/java/mcp-overview).
10+
For comprehensive guides and SDK API documentation
11+
12+
- [Features](https://modelcontextprotocol.io/sdk/java/mcp-overview#features) - Overview the features provided by the Java MCP SDK
13+
- [Acrchitecture](https://modelcontextprotocol.io/sdk/java/mcp-overview#architecture) - Java MCP SDK architecture overview.
14+
- [Java Dependencies / BOM](https://modelcontextprotocol.io/sdk/java/mcp-overview#dependencies) - Java dependencies and BOM.
15+
- [Java MCP Client](https://modelcontextprotocol.io/sdk/java/mcp-client) - Learn how to use the MCP client to interact with MCP servers.
16+
- [Java MCP Server](https://modelcontextprotocol.io/sdk/java/mcp-server) - Learn how to implement and configure a MCP servers.
1117

1218
#### Spring AI MCP documentation
1319
[Spring AI MCP](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-overview.html) extends the MCP Java SDK with Spring Boot integration, providing both [client](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-client-boot-starter-docs.html) and [server](https://docs.spring.io/spring-ai/reference/api/mcp/mcp-server-boot-starter-docs.html) starters. Bootstrap your AI applications with MCP support using [Spring Initializer](https://start.spring.io).

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,14 @@ public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchem
127127
}
128128

129129
private DefaultMcpTransportSession createTransportSession() {
130-
Supplier<Publisher<Void>> onClose = () -> {
131-
DefaultMcpTransportSession transportSession = this.activeSession.get();
132-
return transportSession.sessionId().isEmpty() ? Mono.empty()
133-
: webClient.delete().uri(this.endpoint).headers(httpHeaders -> {
134-
httpHeaders.add(MCP_SESSION_ID, transportSession.sessionId().get());
135-
}).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then();
136-
};
130+
Function<String, Publisher<Void>> onClose = sessionId -> sessionId == null ? Mono.empty()
131+
: webClient.delete().uri(this.endpoint).headers(httpHeaders -> {
132+
httpHeaders.add(MCP_SESSION_ID, sessionId);
133+
})
134+
.retrieve()
135+
.toBodilessEntity()
136+
.doOnError(e -> logger.warn("Got error when closing transport", e))
137+
.then();
137138
return new DefaultMcpTransportSession(onClose);
138139
}
139140

@@ -194,6 +195,7 @@ private Mono<Disposable> reconnect(McpTransportStream<Disposable> stream) {
194195
})
195196
.exchangeToFlux(response -> {
196197
if (isEventStream(response)) {
198+
logger.debug("Established SSE stream via GET");
197199
return eventStream(stream, response);
198200
}
199201
else if (isNotAllowed(response)) {
@@ -210,6 +212,7 @@ else if (isNotFound(response)) {
210212
}).flux();
211213
}
212214
})
215+
.flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage)))
213216
.onErrorComplete(t -> {
214217
this.handleException(t);
215218
return true;
@@ -275,6 +278,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
275278
else {
276279
MediaType mediaType = contentType.get();
277280
if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
281+
logger.debug("Established SSE stream via POST");
278282
// communicate to caller that the message was delivered
279283
sink.success();
280284
// starting a stream

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package io.modelcontextprotocol.client;
22

3-
import com.fasterxml.jackson.databind.ObjectMapper;
43
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
54
import io.modelcontextprotocol.spec.McpClientTransport;
65
import org.junit.jupiter.api.Timeout;
7-
import org.springframework.web.reactive.function.client.WebClient;
86
import org.testcontainers.containers.GenericContainer;
97
import org.testcontainers.containers.wait.strategy.Wait;
108

9+
import org.springframework.web.reactive.function.client.WebClient;
10+
1111
@Timeout(15)
1212
public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests {
1313

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99
import java.util.Objects;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -19,6 +20,8 @@
1920
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
2021
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
2122
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
23+
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
24+
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
2225
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
2326
import io.modelcontextprotocol.spec.McpSchema.Prompt;
2427
import io.modelcontextprotocol.spec.McpSchema.Resource;
@@ -38,6 +41,7 @@
3841
import static org.assertj.core.api.Assertions.assertThat;
3942
import static org.assertj.core.api.Assertions.assertThatCode;
4043
import static org.assertj.core.api.Assertions.assertThatThrownBy;
44+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4145

4246
/**
4347
* Test suite for the {@link McpAsyncClient} that can be used with different
@@ -77,7 +81,9 @@ McpAsyncClient client(McpClientTransport transport, Function<McpClient.AsyncSpec
7781
McpClient.AsyncSpec builder = McpClient.async(transport)
7882
.requestTimeout(getRequestTimeout())
7983
.initializationTimeout(getInitializationTimeout())
80-
.capabilities(ClientCapabilities.builder().roots(true).build());
84+
.sampling(req -> Mono.just(new CreateMessageResult(McpSchema.Role.USER,
85+
new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN)))
86+
.capabilities(ClientCapabilities.builder().roots(true).sampling().build());
8187
builder = customizer.apply(builder);
8288
client.set(builder.build());
8389
}).doesNotThrowAnyException();
@@ -424,6 +430,20 @@ void testInitializeWithSamplingCapability() {
424430
});
425431
}
426432

433+
@Test
434+
void testInitializeWithElicitationCapability() {
435+
ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build();
436+
ElicitResult elicitResult = ElicitResult.builder()
437+
.message(ElicitResult.Action.ACCEPT)
438+
.content(Map.of("foo", "bar"))
439+
.build();
440+
withClient(createMcpTransport(),
441+
builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)),
442+
client -> {
443+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
444+
});
445+
}
446+
427447
@Test
428448
void testInitializeWithAllCapabilities() {
429449
var capabilities = ClientCapabilities.builder()
@@ -435,7 +455,11 @@ void testInitializeWithAllCapabilities() {
435455
Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler = request -> Mono
436456
.just(CreateMessageResult.builder().message("test").model("test-model").build());
437457

438-
withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler),
458+
Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler = request -> Mono
459+
.just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build());
460+
461+
withClient(createMcpTransport(),
462+
builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler),
439463
client ->
440464

441465
StepVerifier.create(client.initialize()).assertNext(result -> {
@@ -487,4 +511,52 @@ void testLoggingWithNullNotification() {
487511
});
488512
}
489513

514+
@Test
515+
void testSampling() {
516+
McpClientTransport transport = createMcpTransport();
517+
518+
final String message = "Hello, world!";
519+
final String response = "Goodbye, world!";
520+
final int maxTokens = 100;
521+
522+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
523+
AtomicReference<String> receivedMessage = new AtomicReference<>();
524+
AtomicInteger receivedMaxTokens = new AtomicInteger();
525+
526+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
527+
.sampling(request -> {
528+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
529+
request.messages().get(0).content());
530+
receivedPrompt.set(request.systemPrompt());
531+
receivedMessage.set(messageText.text());
532+
receivedMaxTokens.set(request.maxTokens());
533+
534+
return Mono
535+
.just(new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
536+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN));
537+
}), client -> {
538+
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
539+
540+
StepVerifier.create(client.callTool(
541+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))))
542+
.consumeNextWith(result -> {
543+
// Verify tool response to ensure our sampling response was passed
544+
// through
545+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
546+
assertThat(result.content()).allSatisfy(content -> {
547+
if (!(content instanceof McpSchema.TextContent text))
548+
return;
549+
550+
assertThat(text.text()).endsWith(response); // Prefixed
551+
});
552+
553+
// Verify sampling request parameters received in our callback
554+
assertThat(receivedPrompt.get()).isNotEmpty();
555+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
556+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
557+
})
558+
.verifyComplete();
559+
});
560+
}
561+
490562
}

mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.concurrent.atomic.AtomicBoolean;
11+
import java.util.concurrent.atomic.AtomicInteger;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Consumer;
1314
import java.util.function.Function;
@@ -31,13 +32,12 @@
3132
import org.junit.jupiter.api.BeforeEach;
3233
import org.junit.jupiter.api.Test;
3334
import reactor.core.publisher.Mono;
34-
import reactor.core.scheduler.Scheduler;
35-
import reactor.core.scheduler.Schedulers;
3635
import reactor.test.StepVerifier;
3736

3837
import static org.assertj.core.api.Assertions.assertThat;
3938
import static org.assertj.core.api.Assertions.assertThatCode;
4039
import static org.assertj.core.api.Assertions.assertThatThrownBy;
40+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
4141

4242
/**
4343
* Unit tests for MCP Client Session functionality.
@@ -438,4 +438,48 @@ void testLoggingWithNullNotification() {
438438
.hasMessageContaining("Logging level must not be null"));
439439
}
440440

441+
@Test
442+
void testSampling() {
443+
McpClientTransport transport = createMcpTransport();
444+
445+
final String message = "Hello, world!";
446+
final String response = "Goodbye, world!";
447+
final int maxTokens = 100;
448+
449+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
450+
AtomicReference<String> receivedMessage = new AtomicReference<>();
451+
AtomicInteger receivedMaxTokens = new AtomicInteger();
452+
453+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
454+
.sampling(request -> {
455+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
456+
request.messages().get(0).content());
457+
receivedPrompt.set(request.systemPrompt());
458+
receivedMessage.set(messageText.text());
459+
receivedMaxTokens.set(request.maxTokens());
460+
461+
return new McpSchema.CreateMessageResult(McpSchema.Role.USER, new McpSchema.TextContent(response),
462+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN);
463+
}), client -> {
464+
client.initialize();
465+
466+
McpSchema.CallToolResult result = client.callTool(
467+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)));
468+
469+
// Verify tool response to ensure our sampling response was passed through
470+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
471+
assertThat(result.content()).allSatisfy(content -> {
472+
if (!(content instanceof McpSchema.TextContent text))
473+
return;
474+
475+
assertThat(text.text()).endsWith(response); // Prefixed
476+
});
477+
478+
// Verify sampling request parameters received in our callback
479+
assertThat(receivedPrompt.get()).isNotEmpty();
480+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
481+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
482+
});
483+
}
484+
441485
}

mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import java.util.Optional;
1111
import java.util.concurrent.atomic.AtomicBoolean;
1212
import java.util.concurrent.atomic.AtomicReference;
13-
import java.util.function.Supplier;
13+
import java.util.function.Function;
1414

1515
/**
1616
* Default implementation of {@link McpTransportSession} which manages the open
@@ -29,9 +29,9 @@ public class DefaultMcpTransportSession implements McpTransportSession<Disposabl
2929

3030
private final AtomicReference<String> sessionId = new AtomicReference<>();
3131

32-
private final Supplier<Publisher<Void>> onClose;
32+
private final Function<String, Publisher<Void>> onClose;
3333

34-
public DefaultMcpTransportSession(Supplier<Publisher<Void>> onClose) {
34+
public DefaultMcpTransportSession(Function<String, Publisher<Void>> onClose) {
3535
this.onClose = onClose;
3636
}
3737

@@ -73,7 +73,8 @@ public void close() {
7373

7474
@Override
7575
public Mono<Void> closeGracefully() {
76-
return Mono.from(this.onClose.get()).then(Mono.fromRunnable(this.openConnections::dispose));
76+
return Mono.from(this.onClose.apply(this.sessionId.get()))
77+
.then(Mono.fromRunnable(this.openConnections::dispose));
7778
}
7879

7980
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1371,8 +1371,9 @@ public record CompleteCompletion(
13711371
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
13721372
@JsonSubTypes({ @JsonSubTypes.Type(value = TextContent.class, name = "text"),
13731373
@JsonSubTypes.Type(value = ImageContent.class, name = "image"),
1374+
@JsonSubTypes.Type(value = AudioContent.class, name = "audio"),
13741375
@JsonSubTypes.Type(value = EmbeddedResource.class, name = "resource") })
1375-
public sealed interface Content permits TextContent, ImageContent, EmbeddedResource {
1376+
public sealed interface Content permits TextContent, ImageContent, AudioContent, EmbeddedResource {
13761377

13771378
default String type() {
13781379
if (this instanceof TextContent) {
@@ -1381,6 +1382,9 @@ default String type() {
13811382
else if (this instanceof ImageContent) {
13821383
return "image";
13831384
}
1385+
else if (this instanceof AudioContent) {
1386+
return "audio";
1387+
}
13841388
else if (this instanceof EmbeddedResource) {
13851389
return "resource";
13861390
}
@@ -1410,6 +1414,14 @@ public record ImageContent( // @formatter:off
14101414
@JsonProperty("mimeType") String mimeType) implements Content { // @formatter:on
14111415
}
14121416

1417+
@JsonInclude(JsonInclude.Include.NON_ABSENT)
1418+
@JsonIgnoreProperties(ignoreUnknown = true)
1419+
public record AudioContent( // @formatter:off
1420+
@JsonProperty("annotations") Annotations annotations,
1421+
@JsonProperty("data") String data,
1422+
@JsonProperty("mimeType") String mimeType) implements Annotated, Content { // @formatter:on
1423+
}
1424+
14131425
@JsonInclude(JsonInclude.Include.NON_ABSENT)
14141426
@JsonIgnoreProperties(ignoreUnknown = true)
14151427
public record EmbeddedResource( // @formatter:off

0 commit comments

Comments
 (0)