1
1
from __future__ import annotations
2
2
3
3
import base64
4
+ import functools
4
5
import json
5
6
from abc import ABC , abstractmethod
6
7
from collections .abc import AsyncIterator , Sequence
11
12
from typing import Any
12
13
13
14
import anyio
15
+ import httpx
14
16
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
17
+ from mcp .shared .message import SessionMessage
15
18
from mcp .types import (
16
19
BlobResourceContents ,
17
20
EmbeddedResource ,
18
21
ImageContent ,
19
- JSONRPCMessage ,
20
22
LoggingLevel ,
21
23
TextContent ,
22
24
TextResourceContents ,
@@ -56,8 +58,8 @@ class MCPServer(ABC):
56
58
"""
57
59
58
60
_client : ClientSession
59
- _read_stream : MemoryObjectReceiveStream [JSONRPCMessage | Exception ]
60
- _write_stream : MemoryObjectSendStream [JSONRPCMessage ]
61
+ _read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ]
62
+ _write_stream : MemoryObjectSendStream [SessionMessage ]
61
63
_exit_stack : AsyncExitStack
62
64
63
65
@abstractmethod
@@ -66,8 +68,8 @@ async def client_streams(
66
68
self ,
67
69
) -> AsyncIterator [
68
70
tuple [
69
- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
70
- MemoryObjectSendStream [JSONRPCMessage ],
71
+ MemoryObjectReceiveStream [SessionMessage | Exception ],
72
+ MemoryObjectSendStream [SessionMessage ],
71
73
]
72
74
]:
73
75
"""Create the streams for the MCP server."""
@@ -266,8 +268,8 @@ async def client_streams(
266
268
self ,
267
269
) -> AsyncIterator [
268
270
tuple [
269
- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
270
- MemoryObjectSendStream [JSONRPCMessage ],
271
+ MemoryObjectReceiveStream [SessionMessage | Exception ],
272
+ MemoryObjectSendStream [SessionMessage ],
271
273
]
272
274
]:
273
275
server = StdioServerParameters (command = self .command , args = list (self .args ), env = self .env , cwd = self .cwd )
@@ -326,6 +328,31 @@ async def main():
326
328
327
329
These headers will be passed directly to the underlying `httpx.AsyncClient`.
328
330
Useful for authentication, custom headers, or other HTTP-specific configurations.
331
+
332
+ !!! note
333
+ You can either pass `headers` or `http_client`, but not both.
334
+
335
+ See [`MCPServerHTTP.http_client`][pydantic_ai.mcp.MCPServerHTTP.http_client] for more information.
336
+ """
337
+
338
+ http_client : httpx .AsyncClient | None = None
339
+ """An `httpx.AsyncClient` to use with the SSE endpoint.
340
+
341
+ This client may be configured to use customized connection parameters like self-signed certificates.
342
+
343
+ !!! note
344
+ You can either pass `headers` or `http_client`, but not both.
345
+
346
+ If you want to use both, you can pass the headers to the `http_client` instead:
347
+
348
+ ```python {py="3.10"}
349
+ import httpx
350
+
351
+ from pydantic_ai.mcp import MCPServerHTTP
352
+
353
+ http_client = httpx.AsyncClient(headers={'Authorization': 'Bearer ...'})
354
+ server = MCPServerHTTP('http://localhost:3001/sse', http_client=http_client)
355
+ ```
329
356
"""
330
357
331
358
timeout : float = 5
@@ -362,18 +389,33 @@ async def main():
362
389
async def client_streams (
363
390
self ,
364
391
) -> AsyncIterator [
365
- tuple [
366
- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
367
- MemoryObjectSendStream [JSONRPCMessage ],
368
- ]
392
+ tuple [MemoryObjectReceiveStream [SessionMessage | Exception ], MemoryObjectSendStream [SessionMessage ]]
369
393
]: # pragma: no cover
370
- async with sse_client (
394
+ if self .http_client and self .headers :
395
+ raise ValueError ('`http_client` is mutually exclusive with `headers`.' )
396
+
397
+ sse_client_partial = functools .partial (
398
+ sse_client ,
371
399
url = self .url ,
372
- headers = self .headers ,
373
400
timeout = self .timeout ,
374
401
sse_read_timeout = self .sse_read_timeout ,
375
- ) as (read_stream , write_stream ):
376
- yield read_stream , write_stream
402
+ )
403
+
404
+ if self .http_client is not None :
405
+
406
+ def httpx_client_factory (
407
+ headers : dict [str , str ] | None = None ,
408
+ timeout : httpx .Timeout | None = None ,
409
+ auth : httpx .Auth | None = None ,
410
+ ) -> httpx .AsyncClient :
411
+ assert self .http_client is not None
412
+ return self .http_client
413
+
414
+ async with sse_client_partial (httpx_client_factory = httpx_client_factory ) as (read_stream , write_stream ):
415
+ yield read_stream , write_stream
416
+ else :
417
+ async with sse_client_partial (headers = self .headers ) as (read_stream , write_stream ):
418
+ yield read_stream , write_stream
377
419
378
420
def _get_log_level (self ) -> LoggingLevel | None :
379
421
return self .log_level
0 commit comments