Skip to content

test: enhance resource reading tests #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package io.modelcontextprotocol.client;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import org.junit.jupiter.api.Timeout;
import org.springframework.web.reactive.function.client.WebClient;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import org.testcontainers.images.builder.ImageFromDockerfile;

import org.springframework.web.reactive.function.client.WebClient;

@Timeout(15)
public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import java.time.Duration;

import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import org.junit.jupiter.api.Timeout;
import org.springframework.web.reactive.function.client.WebClient;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;

import org.springframework.web.reactive.function.client.WebClient;
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;

/**
* Tests for the {@link McpAsyncClient} with {@link WebFluxSseClientTransport}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,49 @@

package io.modelcontextprotocol.client;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
import io.modelcontextprotocol.spec.McpSchema.Prompt;
import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult;
import io.modelcontextprotocol.spec.McpSchema.Resource;
import io.modelcontextprotocol.spec.McpSchema.ResourceContents;
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest;
import io.modelcontextprotocol.spec.McpSchema.TextResourceContents;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest;
import io.modelcontextprotocol.spec.McpTransport;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Test suite for the {@link McpAsyncClient} that can be used with different
* {@link McpTransport} implementations.
Expand Down Expand Up @@ -339,18 +347,59 @@ void testRemoveNonExistentRoot() {
}

@Test
@Disabled
void testReadResource() {
withClient(createMcpTransport(), mcpAsyncClient -> {
StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> {
if (!resources.resources().isEmpty()) {
Resource firstResource = resources.resources().get(0);
StepVerifier.create(mcpAsyncClient.readResource(firstResource)).consumeNextWith(result -> {
assertThat(result).isNotNull();
assertThat(result.contents()).isNotNull();
}).verifyComplete();
withClient(createMcpTransport(), client -> {
Flux<McpSchema.ReadResourceResult> resources = client.initialize()
.then(client.listResources(null))
.flatMapMany(r -> Flux.fromIterable(r.resources()))
.flatMap(r -> client.readResource(r));

StepVerifier.create(resources).recordWith(ArrayList::new).consumeRecordedWith(readResourceResults -> {

for (ReadResourceResult result : readResourceResults) {

assertThat(result).isNotNull();
assertThat(result.contents()).isNotNull().isNotEmpty();

// Validate each content item
for (ResourceContents content : result.contents()) {
assertThat(content).isNotNull();
assertThat(content.uri()).isNotNull().isNotEmpty();
assertThat(content.mimeType()).isNotNull().isNotEmpty();

// Validate content based on its type with more comprehensive
// checks
switch (content.mimeType()) {
case "text/plain" -> {
TextResourceContents textContent = assertInstanceOf(TextResourceContents.class,
content);
assertThat(textContent.text()).isNotNull().isNotEmpty();
assertThat(textContent.uri()).isNotEmpty();
}
case "application/octet-stream" -> {
BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class,
content);
assertThat(blobContent.blob()).isNotNull().isNotEmpty();
assertThat(blobContent.uri()).isNotNull().isNotEmpty();
// Validate base64 encoding format
assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$");
}
default -> {

// Still validate basic properties
if (content instanceof TextResourceContents textContent) {
assertThat(textContent.text()).isNotNull();
}
else if (content instanceof BlobResourceContents blobContent) {
assertThat(blobContent.blob()).isNotNull();
}
}
}
}
}
}).verifyComplete();
})
.expectNextCount(10) // Expect 10 elements
.verifyComplete();
});
}

Expand Down Expand Up @@ -424,6 +473,20 @@ void testInitializeWithSamplingCapability() {
});
}

@Test
void testInitializeWithElicitationCapability() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to the ReadResources, but discovered a missing elicitation aync testing

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you - I'll be sure to double-check this in other PRs 😓

ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build();
ElicitResult elicitResult = ElicitResult.builder()
.message(ElicitResult.Action.ACCEPT)
.content(Map.of("foo", "bar"))
.build();
withClient(createMcpTransport(),
builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)),
client -> {
StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete();
});
}

@Test
void testInitializeWithAllCapabilities() {
var capabilities = ClientCapabilities.builder()
Expand All @@ -435,7 +498,11 @@ void testInitializeWithAllCapabilities() {
Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler = request -> Mono
.just(CreateMessageResult.builder().message("test").model("test-model").build());

withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler),
Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler = request -> Mono
.just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build());

withClient(createMcpTransport(),
builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler),
client ->

StepVerifier.create(client.initialize()).assertNext(result -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

package io.modelcontextprotocol.client;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

import java.time.Duration;
import java.util.List;
import java.util.Map;
Expand All @@ -12,8 +17,15 @@
import java.util.function.Consumer;
import java.util.function.Function;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
Expand All @@ -22,23 +34,17 @@
import io.modelcontextprotocol.spec.McpSchema.ListToolsResult;
import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult;
import io.modelcontextprotocol.spec.McpSchema.Resource;
import io.modelcontextprotocol.spec.McpSchema.ResourceContents;
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.SubscribeRequest;
import io.modelcontextprotocol.spec.McpSchema.TextContent;
import io.modelcontextprotocol.spec.McpSchema.TextResourceContents;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.spec.McpSchema.UnsubscribeRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* Unit tests for MCP Client Session functionality.
*
Expand All @@ -47,6 +53,8 @@
*/
public abstract class AbstractMcpSyncClientTests {

private static final Logger logger = LoggerFactory.getLogger(AbstractMcpSyncClientTests.class);

private static final String TEST_MESSAGE = "Hello MCP Spring AI!";

abstract protected McpClientTransport createMcpTransport();
Expand Down Expand Up @@ -121,9 +129,9 @@ <T> void verifyNotificationSucceedsWithImplicitInitialization(Consumer<McpSyncCl

<T> void verifyCallSucceedsWithImplicitInitialization(Function<McpSyncClient, T> blockingOperation, String action) {
withClient(createMcpTransport(), mcpSyncClient -> {
StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient)))
.expectNextCount(1)
.verifyComplete();
StepVerifier.create(Mono.fromSupplier(() -> blockingOperation.apply(mcpSyncClient))
// Offload the blocking call to the real scheduler
.subscribeOn(Schedulers.boundedElastic())).expectNextCount(1).verifyComplete();
});
}

Expand Down Expand Up @@ -331,16 +339,70 @@ void testReadResourceWithoutInitialization() {
@Test
void testReadResource() {
withClient(createMcpTransport(), mcpSyncClient -> {

int readResourceCount = 0;

mcpSyncClient.initialize();
ListResourcesResult resources = mcpSyncClient.listResources(null);

if (!resources.resources().isEmpty()) {
Resource firstResource = resources.resources().get(0);
ReadResourceResult result = mcpSyncClient.readResource(firstResource);
assertThat(resources).isNotNull();
assertThat(resources.resources()).isNotNull();

assertThat(resources.resources()).isNotNull().isNotEmpty();

// Test reading each resource individually for better error isolation
for (Resource resource : resources.resources()) {
ReadResourceResult result = mcpSyncClient.readResource(resource);

assertThat(result).isNotNull();
assertThat(result.contents()).isNotNull();
assertThat(result.contents()).isNotNull().isNotEmpty();

readResourceCount++;

// Validate each content item
for (ResourceContents content : result.contents()) {
assertThat(content).isNotNull();
assertThat(content.uri()).isNotNull().isNotEmpty();
assertThat(content.mimeType()).isNotNull().isNotEmpty();

// Validate content based on its type with more comprehensive
// checks
switch (content.mimeType()) {
case "text/plain" -> {
TextResourceContents textContent = assertInstanceOf(TextResourceContents.class, content);
assertThat(textContent.text()).isNotNull().isNotEmpty();
// Verify URI consistency
assertThat(textContent.uri()).isEqualTo(resource.uri());
}
case "application/octet-stream" -> {
BlobResourceContents blobContent = assertInstanceOf(BlobResourceContents.class, content);
assertThat(blobContent.blob()).isNotNull().isNotEmpty();
// Verify URI consistency
assertThat(blobContent.uri()).isEqualTo(resource.uri());
// Validate base64 encoding format
assertThat(blobContent.blob()).matches("^[A-Za-z0-9+/]*={0,2}$");
}
default -> {
// More flexible handling of additional MIME types
// Log the unexpected type for debugging but don't fail
// the test
logger.warn("Warning: Encountered unexpected MIME type: {} for resource: {}",
content.mimeType(), resource.uri());

// Still validate basic properties
if (content instanceof TextResourceContents textContent) {
assertThat(textContent.text()).isNotNull();
}
else if (content instanceof BlobResourceContents blobContent) {
assertThat(blobContent.blob()).isNotNull();
}
}
}
}
}

// Assert that we read exactly 10 resources
assertThat(readResourceCount).isEqualTo(10);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,14 @@ public Mono<Void> closeGracefully() {
return Mono.fromRunnable(() -> {
isClosing = true;
logger.debug("Initiating graceful shutdown");
}).then(Mono.defer(() -> {
}).then(Mono.<Void>defer(() -> {
// First complete all sinks to stop accepting new messages
inboundSink.tryEmitComplete();
outboundSink.tryEmitComplete();
errorSink.tryEmitComplete();

// Give a short time for any pending messages to be processed
return Mono.delay(Duration.ofMillis(100));
return Mono.delay(Duration.ofMillis(100)).then();
})).then(Mono.defer(() -> {
logger.debug("Sending TERM to process");
if (this.process != null) {
Expand Down
Loading