Skip to content

Commit f86a9cd

Browse files
committed
adjust types after the spec revision
1 parent 07527fa commit f86a9cd

File tree

8 files changed

+183
-78
lines changed

8 files changed

+183
-78
lines changed

src/mcp/client/session.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,11 @@ async def initialize(self) -> types.InitializeResult:
143143
if self._sampling_callback is not _default_sampling_callback
144144
else None
145145
)
146-
# TODO: change this to a more specific type
147-
elicitation = types.ElicitationCapability()
146+
elicitation = (
147+
types.ElicitationCapability()
148+
if self._elicitation_callback is not _default_elicitation_callback
149+
else None
150+
)
148151
roots = (
149152
# TODO: Should this be based on whether we
150153
# _will_ send notifications, or only whether

src/mcp/server/fastmcp/server.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
asynccontextmanager,
1111
)
1212
from itertools import chain
13-
from typing import Any, Generic, Literal
13+
from typing import Any, Generic, Literal, TypeVar
1414

1515
import anyio
1616
import pydantic_core
17-
from pydantic import BaseModel, Field
17+
from pydantic import BaseModel, Field, ValidationError
1818
from pydantic.networks import AnyUrl
1919
from pydantic_settings import BaseSettings, SettingsConfigDict
2020
from starlette.applications import Starlette
@@ -67,6 +67,8 @@
6767

6868
logger = get_logger(__name__)
6969

70+
ElicitedModelT = TypeVar("ElicitedModelT", bound=BaseModel)
71+
7072

7173
class Settings(BaseSettings, Generic[LifespanResultT]):
7274
"""FastMCP server settings.
@@ -1005,35 +1007,48 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
10051007
async def elicit(
10061008
self,
10071009
message: str,
1008-
requestedSchema: dict[str, Any],
1009-
) -> dict[str, Any]:
1010+
schema: type[ElicitedModelT],
1011+
) -> ElicitedModelT:
10101012
"""Elicit information from the client/user.
10111013
10121014
This method can be used to interactively ask for additional information from the
1013-
client within a tool's execution.
1014-
The client might display the message to the user and collect a response
1015-
according to the provided schema. Or in case a client is an agent, it might
1016-
decide how to handle the elicitation -- either by asking the user or
1017-
automatically generating a response.
1015+
client within a tool's execution. The client might display the message to the
1016+
user and collect a response according to the provided schema. Or in case a
1017+
client
1018+
is an agent, it might decide how to handle the elicitation -- either by asking
1019+
the user or automatically generating a response.
10181020
10191021
Args:
1020-
message: The message to present to the user
1021-
requestedSchema: JSON Schema defining the expected response structure
1022+
schema: A Pydantic model class defining the expected response structure
1023+
message: Optional message to present to the user. If not provided, will use
1024+
a default message based on the schema
10221025
10231026
Returns:
1024-
The user's response as a dict matching the request schema structure
1027+
An instance of the schema type with the user's response
10251028
10261029
Raises:
1027-
ValueError: If elicitation is not supported by the client or fails
1030+
Exception: If the user declines or cancels the elicitation
1031+
ValidationError: If the response doesn't match the schema
10281032
"""
10291033

1034+
json_schema = schema.model_json_schema()
1035+
10301036
result = await self.request_context.session.elicit(
10311037
message=message,
1032-
requestedSchema=requestedSchema,
1038+
requestedSchema=json_schema,
10331039
related_request_id=self.request_id,
10341040
)
10351041

1036-
return result.content
1042+
if result.action == "accept" and result.content:
1043+
# Validate and parse the content using the schema
1044+
try:
1045+
return schema.model_validate(result.content)
1046+
except ValidationError as e:
1047+
raise ValueError(f"Invalid response: {e}")
1048+
elif result.action == "decline":
1049+
raise Exception("User declined to provide information")
1050+
else:
1051+
raise Exception("User cancelled the request")
10371052

10381053
async def log(
10391054
self,

src/mcp/server/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,14 @@ async def list_roots(self) -> types.ListRootsResult:
277277
async def elicit(
278278
self,
279279
message: str,
280-
requestedSchema: dict[str, Any],
280+
requestedSchema: types.ElicitRequestedSchema,
281281
related_request_id: types.RequestId | None = None,
282282
) -> types.ElicitResult:
283283
"""Send an elicitation/create request.
284284
285285
Args:
286286
message: The message to present to the user
287-
requestedSchema: JSON Schema defining the expected response structure
287+
requestedSchema: Schema defining the expected response structure
288288
289289
Returns:
290290
The client's response

src/mcp/shared/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,8 @@ async def _receive_loop(self) -> None:
369369
request=validated_request,
370370
session=self,
371371
on_complete=lambda r: self._in_flight.pop(
372-
r.request_id, None),
372+
r.request_id, None
373+
),
373374
message_metadata=message.metadata,
374375
)
375376
self._in_flight[responder.request_id] = responder
@@ -394,7 +395,8 @@ async def _receive_loop(self) -> None:
394395
),
395396
)
396397
session_message = SessionMessage(
397-
message=JSONRPCMessage(error_response))
398+
message=JSONRPCMessage(error_response)
399+
)
398400
await self._write_stream.send(session_message)
399401

400402
elif isinstance(message.message.root, JSONRPCNotification):

src/mcp/types.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,16 +1174,16 @@ class ClientNotification(
11741174
pass
11751175

11761176

1177+
# Type for elicitation schema - a JSON Schema dict
1178+
ElicitRequestedSchema: TypeAlias = dict[str, Any]
1179+
"""Schema for elicitation requests."""
1180+
1181+
11771182
class ElicitRequestParams(RequestParams):
11781183
"""Parameters for elicitation requests."""
11791184

11801185
message: str
1181-
"""The message to present to the user."""
1182-
1183-
requestedSchema: dict[str, Any]
1184-
"""
1185-
A JSON Schema object defining the expected structure of the response.
1186-
"""
1186+
requestedSchema: ElicitRequestedSchema
11871187
model_config = ConfigDict(extra="allow")
11881188

11891189

@@ -1195,10 +1195,21 @@ class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]])
11951195

11961196

11971197
class ElicitResult(Result):
1198-
"""The client's response to an elicitation/create request from the server."""
1198+
"""The client's response to an elicitation request."""
11991199

1200-
content: dict[str, Any]
1201-
"""The response from the client, matching the structure of requestedSchema."""
1200+
action: Literal["accept", "decline", "cancel"]
1201+
"""
1202+
The user action in response to the elicitation.
1203+
- "accept": User submitted the form/confirmed the action
1204+
- "decline": User explicitly declined the action
1205+
- "cancel": User dismissed without making an explicit choice
1206+
"""
1207+
1208+
content: dict[str, str | int | float | bool | None] | None = None
1209+
"""
1210+
The submitted form data, only present when action is "accept".
1211+
Contains values matching the requested schema.
1212+
"""
12021213

12031214

12041215
class ClientResult(

tests/issues/test_malformed_input.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Claude Debug
1+
# Claude Debug
22
"""Test for HackerOne vulnerability report #3156202 - malformed input DOS."""
33

44
import anyio
@@ -38,7 +38,7 @@ async def test_malformed_initialize_request_does_not_crash_server():
3838
method="initialize",
3939
# params=None # Missing required params field
4040
)
41-
41+
4242
# Wrap in session message
4343
request_message = SessionMessage(message=JSONRPCMessage(malformed_request))
4444

@@ -54,22 +54,22 @@ async def test_malformed_initialize_request_does_not_crash_server():
5454
):
5555
# Send the malformed request
5656
await read_send_stream.send(request_message)
57-
57+
5858
# Give the session time to process the request
5959
await anyio.sleep(0.1)
60-
60+
6161
# Check that we received an error response instead of a crash
6262
try:
6363
response_message = write_receive_stream.receive_nowait()
6464
response = response_message.message.root
65-
65+
6666
# Verify it's a proper JSON-RPC error response
6767
assert isinstance(response, JSONRPCError)
6868
assert response.jsonrpc == "2.0"
6969
assert response.id == "f20fe86132ed4cd197f89a7134de5685"
7070
assert response.error.code == INVALID_PARAMS
7171
assert "Invalid request parameters" in response.error.message
72-
72+
7373
# Verify the session is still alive and can handle more requests
7474
# Send another malformed request to confirm server stability
7575
another_malformed_request = JSONRPCRequest(
@@ -81,18 +81,18 @@ async def test_malformed_initialize_request_does_not_crash_server():
8181
another_request_message = SessionMessage(
8282
message=JSONRPCMessage(another_malformed_request)
8383
)
84-
84+
8585
await read_send_stream.send(another_request_message)
8686
await anyio.sleep(0.1)
87-
87+
8888
# Should get another error response, not a crash
8989
second_response_message = write_receive_stream.receive_nowait()
9090
second_response = second_response_message.message.root
91-
91+
9292
assert isinstance(second_response, JSONRPCError)
9393
assert second_response.id == "test_id_2"
9494
assert second_response.error.code == INVALID_PARAMS
95-
95+
9696
except anyio.WouldBlock:
9797
pytest.fail("No response received - server likely crashed")
9898
finally:
@@ -140,14 +140,14 @@ async def test_multiple_concurrent_malformed_requests():
140140
message=JSONRPCMessage(malformed_request)
141141
)
142142
malformed_requests.append(request_message)
143-
143+
144144
# Send all requests
145145
for request in malformed_requests:
146146
await read_send_stream.send(request)
147-
147+
148148
# Give time to process
149149
await anyio.sleep(0.2)
150-
150+
151151
# Verify we get error responses for all requests
152152
error_responses = []
153153
try:
@@ -156,10 +156,10 @@ async def test_multiple_concurrent_malformed_requests():
156156
error_responses.append(response_message.message.root)
157157
except anyio.WouldBlock:
158158
pass # No more messages
159-
159+
160160
# Should have received 10 error responses
161161
assert len(error_responses) == 10
162-
162+
163163
for i, response in enumerate(error_responses):
164164
assert isinstance(response, JSONRPCError)
165165
assert response.id == f"malformed_{i}"
@@ -169,4 +169,4 @@ async def test_multiple_concurrent_malformed_requests():
169169
await read_send_stream.aclose()
170170
await write_send_stream.aclose()
171171
await read_receive_stream.aclose()
172-
await write_receive_stream.aclose()
172+
await write_receive_stream.aclose()

0 commit comments

Comments
 (0)