Skip to content

Commit ff01365

Browse files
committed
rebase on lowlevel schema validation, address comments (wip)
1 parent 3e018d6 commit ff01365

File tree

5 files changed

+263
-301
lines changed

5 files changed

+263
-301
lines changed

src/mcp/server/fastmcp/tools/base.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
import inspect
55
from collections.abc import Callable, Sequence
66
from functools import cached_property
7-
from itertools import chain
87
from typing import TYPE_CHECKING, Any, get_origin
98

10-
import pydantic_core
119
from pydantic import BaseModel, Field
1210

1311
from mcp.server.fastmcp.exceptions import ToolError
1412
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
15-
from mcp.server.fastmcp.utilities.types import Image
16-
from mcp.types import ContentBlock, TextContent, ToolAnnotations
13+
from mcp.types import ContentBlock, ToolAnnotations
1714

1815
if TYPE_CHECKING:
1916
from mcp.server.fastmcp.server import Context
@@ -112,28 +109,8 @@ async def run(
112109
raise ToolError(f"Error executing tool {self.name}: {e}") from e
113110

114111
def convert_result(self, result: Any) -> Sequence[ContentBlock] | dict[str, Any]:
115-
"""Validate tool result and convert to appropriate output format."""
116-
output_model = self.fn_metadata.output_model
117-
if output_model:
118-
# This will raise a ToolError if validation fails
119-
return self.fn_metadata.to_validated_dict(result)
120-
else:
121-
if result is None:
122-
return []
123-
124-
if isinstance(result, ContentBlock):
125-
return [result]
126-
127-
if isinstance(result, Image):
128-
return [result.to_image_content()]
129-
130-
if isinstance(result, list | tuple):
131-
return list(chain.from_iterable(self.convert_result(item) for item in result)) # type: ignore[reportUnknownVariableType]
132-
133-
if not isinstance(result, str):
134-
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()
135-
136-
return [TextContent(type="text", text=result)]
112+
"""convert to appropriate output format."""
113+
return self.fn_metadata.convert_result(result)
137114

138115

139116
def _is_async_callable(obj: Any) -> bool:

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 57 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,10 @@
22
import json
33
from collections.abc import Awaitable, Callable, Sequence
44
from dataclasses import asdict, is_dataclass
5-
from typing import (
6-
Annotated,
7-
Any,
8-
ForwardRef,
9-
Literal,
10-
get_args,
11-
get_origin,
12-
get_type_hints,
13-
)
5+
from itertools import chain
6+
from typing import Annotated, Any, ForwardRef, Literal, get_args, get_origin, get_type_hints
147

8+
import pydantic_core
159
from pydantic import (
1610
BaseModel,
1711
ConfigDict,
@@ -24,12 +18,14 @@
2418
from pydantic.fields import FieldInfo
2519
from pydantic_core import PydanticUndefined
2620

27-
from mcp.server.fastmcp.exceptions import InvalidSignature, ToolError
21+
from mcp.server.fastmcp.exceptions import InvalidSignature
2822
from mcp.server.fastmcp.utilities.logging import get_logger
23+
from mcp.server.fastmcp.utilities.types import Image
24+
from mcp.types import ContentBlock, TextContent
2925

3026
logger = get_logger(__name__)
3127

32-
OutputConversion = Literal["none", "wrapped", "namedtuple", "class"]
28+
OutputConversion = Literal["none", "basemodel", "wrapped", "namedtuple", "class"]
3329

3430

3531
class ArgModelBase(BaseModel):
@@ -54,16 +50,13 @@ class FuncMetadata(BaseModel):
5450
arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
5551
output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
5652
output_conversion: OutputConversion = "none"
57-
# We can add things in the future like
58-
# - Maybe some args are excluded from attempting to parse from JSON
59-
# - Maybe some args are special (like context) for dependency injection
6053

6154
async def call_fn_with_arg_validation(
6255
self,
63-
fn: Callable[..., Any] | Awaitable[Any],
56+
fn: Callable[..., Any | Awaitable[Any]],
6457
fn_is_async: bool,
6558
arguments_to_validate: dict[str, Any],
66-
arguments_to_pass_directly: dict[str, Any] | None,
59+
arguments_to_pass_directly: dict[str, Any] | None = None,
6760
) -> Any:
6861
"""Call the given function with arguments validated and injected.
6962
@@ -75,36 +68,62 @@ async def call_fn_with_arg_validation(
7568
arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()
7669

7770
arguments_parsed_dict |= arguments_to_pass_directly or {}
78-
7971
if fn_is_async:
80-
if isinstance(fn, Awaitable):
81-
return await fn
8272
return await fn(**arguments_parsed_dict)
83-
if isinstance(fn, Callable):
84-
return fn(**arguments_parsed_dict)
85-
raise TypeError("fn must be either Callable or Awaitable")
73+
return fn(**arguments_parsed_dict)
74+
75+
async def call_fn(
76+
self,
77+
fn: Callable[..., Any | Awaitable[Any]],
78+
fn_is_async: bool,
79+
args: dict[str, Any],
80+
implicit_args: dict[str, Any] | None = None,
81+
) -> Any:
82+
parsed_args = self.arg_model.model_construct(**args).model_dump_one_level()
83+
kwargs = parsed_args | (implicit_args or {})
84+
if fn_is_async:
85+
return await fn(**kwargs)
86+
else:
87+
return fn(**kwargs)
8688

87-
def to_validated_dict(self, result: Any) -> dict[str, Any]:
88-
"""Validate and convert the result to a dict after validation."""
89-
if self.output_model is None:
90-
raise ValueError("No output model to validate against")
89+
def convert_result(self, result: Any) -> Sequence[ContentBlock] | dict[str, Any]:
90+
"""convert result to dict"""
91+
if self.output_model:
92+
return self._convert_structured_result(result)
93+
else:
94+
return self._convert_unstructured_result(result)
9195

96+
def _convert_structured_result(self, result: Any) -> dict[str, Any]:
9297
match self.output_conversion:
98+
case "none":
99+
return result
100+
case "basemodel":
101+
return result.model_dump()
93102
case "wrapped":
94-
converted = _convert_wrapped_result(result)
103+
return {"result": result}
95104
case "namedtuple":
96-
converted = _convert_namedtuple_result(result)
105+
return result._asdict()
97106
case "class":
98-
converted = _convert_class_result(result)
99-
case "none":
100-
converted = result
107+
if is_dataclass(result) and not isinstance(result, type):
108+
return asdict(result)
109+
return dict(vars(result))
101110

102-
try:
103-
validated = self.output_model.model_validate(converted)
104-
except Exception as e:
105-
raise ToolError(f"Output validation failed: {e}") from e
111+
def _convert_unstructured_result(self, result: Any) -> Sequence[ContentBlock]:
112+
if result is None:
113+
return []
106114

107-
return validated.model_dump()
115+
if isinstance(result, ContentBlock):
116+
return [result]
117+
118+
if isinstance(result, Image):
119+
return [result.to_image_content()]
120+
121+
if isinstance(result, list | tuple):
122+
return list(chain.from_iterable(self._convert_unstructured_result(item) for item in result)) # type: ignore
123+
124+
if not isinstance(result, str):
125+
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()
126+
return [TextContent(type="text", text=result)]
108127

109128
def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
110129
"""Pre-parse data from JSON.
@@ -177,6 +196,7 @@ def func_metadata(
177196
A FuncMetadata object containing:
178197
- arg_model: A pydantic model representing the function's arguments
179198
- output_model: A pydantic model for the return type (if structured_output=True)
199+
- output_conversion: Records how function output should be converted before returning.
180200
"""
181201
sig = _get_typed_signature(func)
182202
params = sig.parameters
@@ -236,7 +256,7 @@ def func_metadata(
236256
elif isinstance(annotation, type):
237257
if issubclass(annotation, BaseModel):
238258
output_model = annotation
239-
output_conversion = "none"
259+
output_conversion = "basemodel"
240260
elif _is_typeddict(annotation):
241261
output_model = _create_model_from_typeddict(annotation, globalns)
242262
output_conversion = "none"
@@ -338,13 +358,6 @@ def _create_model_from_class(cls: type[Any], globalns: dict[str, Any]) -> type[B
338358
return create_model(cls.__name__, **model_fields, __base__=BaseModel)
339359

340360

341-
def _convert_class_result(result: Any) -> dict[str, Any]:
342-
if is_dataclass(result) and not isinstance(result, type):
343-
return asdict(result)
344-
345-
return dict(vars(result))
346-
347-
348361
def _create_model_from_typeddict(td_type: type[Any], globalns: dict[str, Any]) -> type[BaseModel]:
349362
"""Create a Pydantic model from a TypedDict.
350363
@@ -387,10 +400,6 @@ def _create_model_from_namedtuple(nt_type: type[Any], globalns: dict[str, Any])
387400
return create_model(nt_type.__name__, **model_fields, __base__=BaseModel)
388401

389402

390-
def _convert_namedtuple_result(result: Any) -> dict[str, Any]:
391-
return result._asdict()
392-
393-
394403
def _create_wrapped_model(func_name: str, annotation: Any, field_info: FieldInfo) -> type[BaseModel]:
395404
"""Create a model that wraps a type in a 'result' field.
396405
@@ -405,10 +414,6 @@ def _create_wrapped_model(func_name: str, annotation: Any, field_info: FieldInfo
405414
return create_model(model_name, result=(annotation, field_info), __base__=BaseModel)
406415

407416

408-
def _convert_wrapped_result(result: Any) -> dict[str, Any]:
409-
return {"result": result}
410-
411-
412417
def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]:
413418
"""Create a RootModel for dict[str, T] types."""
414419

0 commit comments

Comments
 (0)