Skip to content

Commit d7d9eab

Browse files
Pass session id's to MCP endpoints. (#466)
* Pass session id's to MCP endpoints. * Update src/ModelContextProtocol/Client/McpClient.cs Co-authored-by: Stephen Halter <halter73@gmail.com> * Update src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs Co-authored-by: Stephen Halter <halter73@gmail.com> * Make init callback asynchronous. * Address feedback. --------- Co-authored-by: Stephen Halter <halter73@gmail.com>
1 parent fa017c0 commit d7d9eab

File tree

20 files changed

+166
-21
lines changed

20 files changed

+166
-21
lines changed

src/ModelContextProtocol.AspNetCore/SseHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public async Task HandleSseRequestAsync(HttpContext context)
3131

3232
var requestPath = (context.Request.PathBase + context.Request.Path).ToString();
3333
var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)];
34-
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}");
34+
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId);
3535

3636
var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User);
3737
await using var httpMcpSession = new HttpMcpSession<SseResponseStreamTransport>(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider);

src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using Microsoft.AspNetCore.WebUtilities;
55
using Microsoft.Extensions.Logging;
66
using Microsoft.Extensions.Options;
7-
using Microsoft.Extensions.Primitives;
87
using Microsoft.Net.Http.Headers;
98
using ModelContextProtocol.AspNetCore.Stateless;
109
using ModelContextProtocol.Protocol;
@@ -136,6 +135,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
136135
var transport = new StreamableHttpServerTransport
137136
{
138137
Stateless = true,
138+
SessionId = sessionId,
139139
};
140140
session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId);
141141
}
@@ -184,7 +184,10 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> StartNewS
184184
if (!HttpServerTransportOptions.Stateless)
185185
{
186186
sessionId = MakeNewSessionId();
187-
transport = new();
187+
transport = new()
188+
{
189+
SessionId = sessionId,
190+
};
188191
context.Response.Headers["mcp-session-id"] = sessionId;
189192
}
190193
else
@@ -286,21 +289,19 @@ internal static string MakeNewSessionId()
286289

287290
private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport)
288291
{
289-
context.Response.OnStarting(() =>
292+
transport.OnInitRequestReceived = initRequestParams =>
290293
{
291294
var statelessId = new StatelessSessionId
292295
{
293-
ClientInfo = transport?.InitializeRequest?.ClientInfo,
296+
ClientInfo = initRequestParams?.ClientInfo,
294297
UserIdClaim = GetUserIdClaim(context.User),
295298
};
296299

297300
var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId);
298-
var sessionId = Protector.Protect(sessionJson);
299-
300-
context.Response.Headers["mcp-session-id"] = sessionId;
301-
302-
return Task.CompletedTask;
303-
});
301+
transport.SessionId = Protector.Protect(sessionJson);
302+
context.Response.Headers["mcp-session-id"] = transport.SessionId;
303+
return ValueTask.CompletedTask;
304+
};
304305
}
305306

306307
internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted)

src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOp
4545

4646
public ChannelReader<JsonRpcMessage> MessageReader => _messageChannel.Reader;
4747

48+
string? ITransport.SessionId => ActiveTransport?.SessionId;
49+
4850
/// <inheritdoc/>
4951
public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
5052
{

src/ModelContextProtocol.Core/Client/McpClient.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.Extensions.Logging;
22
using ModelContextProtocol.Protocol;
3+
using System.Diagnostics;
34
using System.Text.Json;
45

56
namespace ModelContextProtocol.Client;
@@ -93,6 +94,20 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL
9394
}
9495
}
9596

97+
/// <inheritdoc/>
98+
public string? SessionId
99+
{
100+
get
101+
{
102+
if (_sessionTransport is null)
103+
{
104+
throw new InvalidOperationException("Must have already initialized a session when invoking this property.");
105+
}
106+
107+
return _sessionTransport.SessionId;
108+
}
109+
}
110+
96111
/// <inheritdoc/>
97112
public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected.");
98113

src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa
2727
private readonly CancellationTokenSource _connectionCts;
2828
private readonly ILogger _logger;
2929

30-
private string? _mcpSessionId;
3130
private Task? _getReceiveTask;
3231

3332
public StreamableHttpClientSessionTransport(
@@ -85,7 +84,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
8584
},
8685
};
8786

88-
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId);
87+
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId);
8988

9089
var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
9190

@@ -124,7 +123,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
124123
// We've successfully initialized! Copy session-id and start GET request if any.
125124
if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues))
126125
{
127-
_mcpSessionId = sessionIdValues.FirstOrDefault();
126+
SessionId = sessionIdValues.FirstOrDefault();
128127
}
129128

130129
_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
@@ -170,7 +169,7 @@ private async Task ReceiveUnsolicitedMessagesAsync()
170169
// Send a GET request to handle any unsolicited messages not sent over a POST response.
171170
using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
172171
request.Headers.Accept.Add(s_textEventStreamMediaType);
173-
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId);
172+
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId);
174173

175174
using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false);
176175

src/ModelContextProtocol.Core/IMcpEndpoint.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ namespace ModelContextProtocol;
2828
/// </remarks>
2929
public interface IMcpEndpoint : IAsyncDisposable
3030
{
31+
/// <summary>Gets an identifier associated with the current MCP session.</summary>
32+
/// <remarks>
33+
/// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE.
34+
/// Can return <see langword="null"/> if the session hasn't initialized or if the transport doesn't
35+
/// support multiple sessions (as is the case with STDIO).
36+
/// </remarks>
37+
string? SessionId { get; }
38+
3139
/// <summary>
3240
/// Sends a JSON-RPC request to the connected endpoint and waits for a response.
3341
/// </summary>

src/ModelContextProtocol.Core/Protocol/ITransport.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ namespace ModelContextProtocol.Protocol;
2525
/// </remarks>
2626
public interface ITransport : IAsyncDisposable
2727
{
28+
/// <summary>Gets an identifier associated with the current MCP session.</summary>
29+
/// <remarks>
30+
/// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE.
31+
/// Can return <see langword="null"/> if the session hasn't initialized or if the transport doesn't
32+
/// support multiple sessions (as is the case with STDIO).
33+
/// </remarks>
34+
string? SessionId { get; }
35+
2836
/// <summary>
2937
/// Gets a channel reader for receiving messages from the transport.
3038
/// </summary>

src/ModelContextProtocol.Core/Protocol/TransportBase.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ internal TransportBase(string name, Channel<JsonRpcMessage>? messageChannel, ILo
5959
/// <summary>Gets the logger used by this transport.</summary>
6060
private protected ILogger Logger => _logger;
6161

62+
/// <inheritdoc/>
63+
public virtual string? SessionId { get; protected set; }
64+
6265
/// <summary>
6366
/// Gets the name that identifies this transport endpoint in logs.
6467
/// </summary>

src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace ModelContextProtocol.Server;
66
internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer
77
{
88
public string EndpointName => server.EndpointName;
9+
public string? SessionId => transport?.SessionId ?? server.SessionId;
910
public ClientCapabilities? ClientCapabilities => server.ClientCapabilities;
1011
public Implementation? ClientInfo => server.ClientInfo;
1112
public McpServerOptions ServerOptions => server.ServerOptions;

src/ModelContextProtocol.Core/Server/McpServer.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ void Register<TPrimitive>(McpServerPrimitiveCollection<TPrimitive>? collection,
9696
InitializeSession(transport);
9797
}
9898

99+
/// <inheritdoc/>
100+
public string? SessionId => _sessionTransport.SessionId;
101+
102+
/// <inheritdoc/>
99103
public ServerCapabilities ServerCapabilities { get; } = new();
100104

101105
/// <inheritdoc />

0 commit comments

Comments
 (0)